Merge branch 'develop' of github.com:matrix-org/synapse into neilj/mau_tracker

This commit is contained in:
Neil Johnson 2018-08-01 17:49:41 +01:00
commit d766f26de9
48 changed files with 626 additions and 280 deletions

View File

@ -62,4 +62,7 @@ Christoph Witzany <christoph at web.crofting.com>
* Add LDAP support for authentication * Add LDAP support for authentication
Pierre Jaury <pierre at jaury.eu> Pierre Jaury <pierre at jaury.eu>
* Docker packaging * Docker packaging
Serban Constantin <serban.constantin at gmail dot com>
* Small bug fix

View File

@ -1,16 +1,32 @@
FROM docker.io/python:2-alpine3.7 FROM docker.io/python:2-alpine3.7
RUN apk add --no-cache --virtual .nacl_deps su-exec build-base libffi-dev zlib-dev libressl-dev libjpeg-turbo-dev linux-headers postgresql-dev libxslt-dev RUN apk add --no-cache --virtual .nacl_deps \
build-base \
libffi-dev \
libjpeg-turbo-dev \
libressl-dev \
libxslt-dev \
linux-headers \
postgresql-dev \
su-exec \
zlib-dev
COPY . /synapse COPY . /synapse
# A wheel cache may be provided in ./cache for faster build # A wheel cache may be provided in ./cache for faster build
RUN cd /synapse \ RUN cd /synapse \
&& pip install --upgrade pip setuptools psycopg2 lxml \ && pip install --upgrade \
lxml \
pip \
psycopg2 \
setuptools \
&& mkdir -p /synapse/cache \ && mkdir -p /synapse/cache \
&& pip install -f /synapse/cache --upgrade --process-dependency-links . \ && pip install -f /synapse/cache --upgrade --process-dependency-links . \
&& mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \ && mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \
&& rm -rf setup.py setup.cfg synapse && rm -rf \
setup.cfg \
setup.py \
synapse
VOLUME ["/data"] VOLUME ["/data"]

1
changelog.d/2952.bugfix Normal file
View File

@ -0,0 +1 @@
Make /directory/list API return 404 for room not found instead of 400

1
changelog.d/3384.misc Normal file
View File

@ -0,0 +1 @@
Rewrite cache list decorator

1
changelog.d/3543.misc Normal file
View File

@ -0,0 +1 @@
Improve Dockerfile and docker-compose instructions

1
changelog.d/3569.bugfix Normal file
View File

@ -0,0 +1 @@
Unicode passwords are now normalised before hashing, preventing the instance where two different devices or browsers might send a different UTF-8 sequence for the password.

1
changelog.d/3612.misc Normal file
View File

@ -0,0 +1 @@
Make EventStore inherit from EventFederationStore

1
changelog.d/3628.misc Normal file
View File

@ -0,0 +1 @@
Remove unused field "pdu_failures" from transactions.

1
changelog.d/3630.feature Normal file
View File

@ -0,0 +1 @@
Add ability to limit number of monthly active users on the server

1
changelog.d/3634.misc Normal file
View File

@ -0,0 +1 @@
rename replication_layer to federation_client

View File

@ -9,13 +9,7 @@ use that server.
## Build ## Build
Build the docker image with the `docker build` command from the root of the synapse repository. Build the docker image with the `docker-compose build` command.
```
docker build -t docker.io/matrixdotorg/synapse .
```
The `-t` option sets the image tag. Official images are tagged `matrixdotorg/synapse:<version>` where `<version>` is the same as the release tag in the synapse git repository.
You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project. You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project.

View File

@ -6,6 +6,7 @@ version: '3'
services: services:
synapse: synapse:
build: ../..
image: docker.io/matrixdotorg/synapse:latest image: docker.io/matrixdotorg/synapse:latest
# Since snyapse does not retry to connect to the database, restart upon # Since snyapse does not retry to connect to the database, restart upon
# failure # failure

View File

@ -252,10 +252,10 @@ class Auth(object):
if ip_address not in app_service.ip_range_whitelist: if ip_address not in app_service.ip_range_whitelist:
defer.returnValue((None, None)) defer.returnValue((None, None))
if "user_id" not in request.args: if b"user_id" not in request.args:
defer.returnValue((app_service.sender, app_service)) defer.returnValue((app_service.sender, app_service))
user_id = request.args["user_id"][0] user_id = request.args[b"user_id"][0].decode('utf8')
if app_service.sender == user_id: if app_service.sender == user_id:
defer.returnValue((app_service.sender, app_service)) defer.returnValue((app_service.sender, app_service))

View File

@ -55,6 +55,7 @@ class Codes(object):
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):

View File

@ -20,6 +20,8 @@ import sys
from six import iteritems from six import iteritems
from prometheus_client import Gauge
from twisted.application import service from twisted.application import service
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.resource import EncodingResourceWrapper, NoResource from twisted.web.resource import EncodingResourceWrapper, NoResource
@ -301,6 +303,11 @@ class SynapseHomeServer(HomeServer):
quit_with_error(e.message) quit_with_error(e.message)
# Gauges to expose monthly active user control metrics
current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
def setup(config_options): def setup(config_options):
""" """
Args: Args:
@ -516,6 +523,18 @@ def run(hs):
MonthlyActiveUsersStore(hs).reap_monthly_active_users, 1000 * 60 * 60 MonthlyActiveUsersStore(hs).reap_monthly_active_users, 1000 * 60 * 60
) )
@defer.inlineCallbacks
def generate_monthly_active_users():
count = 0
if hs.config.limit_usage_by_mau:
count = yield hs.get_datastore().count_monthly_users()
current_mau_gauge.set(float(count))
max_mau_value_gauge.set(float(hs.config.max_mau_value))
generate_monthly_active_users()
if hs.config.limit_usage_by_mau:
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
if hs.config.report_stats: if hs.config.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals") logger.info("Scheduling stats reporting for 3 hour intervals")
clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)

View File

@ -67,6 +67,14 @@ class ServerConfig(Config):
"block_non_admin_invites", False, "block_non_admin_invites", False,
) )
# Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
if self.limit_usage_by_mau:
self.max_mau_value = config.get(
"max_mau_value", 0,
)
else:
self.max_mau_value = 0
# FIXME: federation_domain_whitelist needs sytests # FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None self.federation_domain_whitelist = None
federation_domain_whitelist = config.get( federation_domain_whitelist = config.get(

View File

@ -207,10 +207,6 @@ class FederationServer(FederationBase):
edu.content edu.content
) )
pdu_failures = getattr(transaction, "pdu_failures", [])
for fail in pdu_failures:
logger.info("Got failure %r", fail)
response = { response = {
"pdus": pdu_results, "pdus": pdu_results,
} }

View File

@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object):
self.edus = SortedDict() # stream position -> Edu self.edus = SortedDict() # stream position -> Edu
self.failures = SortedDict() # stream position -> (destination, Failure)
self.device_messages = SortedDict() # stream position -> destination self.device_messages = SortedDict() # stream position -> destination
self.pos = 1 self.pos = 1
@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object):
for queue_name in [ for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
"edus", "failures", "device_messages", "pos_time", "edus", "device_messages", "pos_time",
]: ]:
register(queue_name, getattr(self, queue_name)) register(queue_name, getattr(self, queue_name))
@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object):
for key in keys[:i]: for key in keys[:i]:
del self.edus[key] del self.edus[key]
# Delete things out of failure map
keys = self.failures.keys()
i = self.failures.bisect_left(position_to_delete)
for key in keys[:i]:
del self.failures[key]
# Delete things out of device map # Delete things out of device map
keys = self.device_messages.keys() keys = self.device_messages.keys()
i = self.device_messages.bisect_left(position_to_delete) i = self.device_messages.bisect_left(position_to_delete)
@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object):
self.notifier.on_new_replication_data() self.notifier.on_new_replication_data()
def send_failure(self, failure, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.failures[pos] = (destination, str(failure))
self.notifier.on_new_replication_data()
def send_device_messages(self, destination): def send_device_messages(self, destination):
"""As per TransactionQueue""" """As per TransactionQueue"""
pos = self._next_pos() pos = self._next_pos()
@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object):
for (pos, edu) in edus: for (pos, edu) in edus:
rows.append((pos, EduRow(edu))) rows.append((pos, EduRow(edu)))
# Fetch changed failures
i = self.failures.bisect_right(from_token)
j = self.failures.bisect_right(to_token) + 1
failures = self.failures.items()[i:j]
for (pos, (destination, failure)) in failures:
rows.append((pos, FailureRow(
destination=destination,
failure=failure,
)))
# Fetch changed device messages # Fetch changed device messages
i = self.device_messages.bisect_right(from_token) i = self.device_messages.bisect_right(from_token)
j = self.device_messages.bisect_right(to_token) + 1 j = self.device_messages.bisect_right(to_token) + 1
@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
buff.edus.setdefault(self.edu.destination, []).append(self.edu) buff.edus.setdefault(self.edu.destination, []).append(self.edu)
class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
"destination", # str
"failure",
))):
"""Streams failures to a remote server. Failures are issued when there was
something wrong with a transaction the remote sent us, e.g. it included
an event that was invalid.
"""
TypeId = "f"
@staticmethod
def from_data(data):
return FailureRow(
destination=data["destination"],
failure=data["failure"],
)
def to_data(self):
return {
"destination": self.destination,
"failure": self.failure,
}
def add_to_buffer(self, buff):
buff.failures.setdefault(self.destination, []).append(self.failure)
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
"destination", # str "destination", # str
))): ))):
@ -471,7 +417,6 @@ TypeToRow = {
PresenceRow, PresenceRow,
KeyedEduRow, KeyedEduRow,
EduRow, EduRow,
FailureRow,
DeviceRow, DeviceRow,
) )
} }
@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
"presence", # list(UserPresenceState) "presence", # list(UserPresenceState)
"keyed_edus", # dict of destination -> { key -> Edu } "keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu] "edus", # dict of destination -> [Edu]
"failures", # dict of destination -> [failures]
"device_destinations", # set of destinations "device_destinations", # set of destinations
)) ))
@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows):
presence=[], presence=[],
keyed_edus={}, keyed_edus={},
edus={}, edus={},
failures={},
device_destinations=set(), device_destinations=set(),
) )
@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows):
edu.destination, edu.edu_type, edu.content, key=None, edu.destination, edu.edu_type, edu.content, key=None,
) )
for destination, failure_list in iteritems(buff.failures):
for failure in failure_list:
transaction_queue.send_failure(destination, failure)
for destination in buff.device_destinations: for destination in buff.device_destinations:
transaction_queue.send_device_messages(destination) transaction_queue.send_device_messages(destination)

View File

@ -116,9 +116,6 @@ class TransactionQueue(object):
), ),
) )
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
# destination -> stream_id of last successfully sent to-device message. # destination -> stream_id of last successfully sent to-device message.
# NB: may be a long or an int. # NB: may be a long or an int.
self.last_device_stream_id_by_dest = {} self.last_device_stream_id_by_dest = {}
@ -382,19 +379,6 @@ class TransactionQueue(object):
self._attempt_new_transaction(destination) self._attempt_new_transaction(destination)
def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
if not self.can_send_to(destination):
return
self.pending_failures_by_dest.setdefault(
destination, []
).append(failure)
self._attempt_new_transaction(destination)
def send_device_messages(self, destination): def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost": if destination == self.server_name or destination == "localhost":
return return
@ -469,7 +453,6 @@ class TransactionQueue(object):
pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {}) pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend( pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values() self.pending_edus_keyed_by_dest.pop(destination, {}).values()
@ -497,7 +480,7 @@ class TransactionQueue(object):
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus)) destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures: if not pending_pdus and not pending_edus:
logger.debug("TX [%s] Nothing to send", destination) logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = ( self.last_device_stream_id_by_dest[destination] = (
device_stream_id device_stream_id
@ -507,7 +490,7 @@ class TransactionQueue(object):
# END CRITICAL SECTION # END CRITICAL SECTION
success = yield self._send_new_transaction( success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures, destination, pending_pdus, pending_edus,
) )
if success: if success:
sent_transactions_counter.inc() sent_transactions_counter.inc()
@ -584,14 +567,12 @@ class TransactionQueue(object):
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus, def _send_new_transaction(self, destination, pending_pdus, pending_edus):
pending_failures):
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[1]) pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus] pdus = [x[0] for x in pending_pdus]
edus = pending_edus edus = pending_edus
failures = [x.get_dict() for x in pending_failures]
success = True success = True
@ -601,11 +582,10 @@ class TransactionQueue(object):
logger.debug( logger.debug(
"TX [%s] {%s} Attempting new transaction" "TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)", " (pdus: %d, edus: %d)",
destination, txn_id, destination, txn_id,
len(pdus), len(pdus),
len(edus), len(edus),
len(failures)
) )
logger.debug("TX [%s] Persisting transaction...", destination) logger.debug("TX [%s] Persisting transaction...", destination)
@ -617,7 +597,6 @@ class TransactionQueue(object):
destination=destination, destination=destination,
pdus=pdus, pdus=pdus,
edus=edus, edus=edus,
pdu_failures=failures,
) )
self._next_txn_id += 1 self._next_txn_id += 1
@ -627,12 +606,11 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisted transaction", destination) logger.debug("TX [%s] Persisted transaction", destination)
logger.info( logger.info(
"TX [%s] {%s} Sending transaction [%s]," "TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)", " (PDUs: %d, EDUs: %d)",
destination, txn_id, destination, txn_id,
transaction.transaction_id, transaction.transaction_id,
len(pdus), len(pdus),
len(edus), len(edus),
len(failures),
) )
# Actually send the transaction # Actually send the transaction

View File

@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes):
param_dict = dict(kv.split("=") for kv in params) param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value): def strip_quotes(value):
if value.startswith(b"\""): if value.startswith("\""):
return value[1:-1] return value[1:-1]
else: else:
return value return value
@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet):
) )
logger.info( logger.info(
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)", "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
transaction_id, origin, transaction_id, origin,
len(transaction_data.get("pdus", [])), len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])), len(transaction_data.get("edus", [])),
len(transaction_data.get("failures", [])),
) )
# We should ideally be getting this from the security layer. # We should ideally be getting this from the security layer.

View File

@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject):
"previous_ids", "previous_ids",
"pdus", "pdus",
"edus", "edus",
"pdu_failures",
] ]
internal_keys = [ internal_keys = [

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import unicodedata
import attr import attr
import bcrypt import bcrypt
@ -519,6 +520,7 @@ class AuthHandler(BaseHandler):
""" """
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
yield self._check_mau_limits()
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we # it's possible we raced against a DELETE operation. The thing we
@ -626,6 +628,7 @@ class AuthHandler(BaseHandler):
# special case to check for "password" for the check_password interface # special case to check for "password" for the check_password interface
# for the auth providers # for the auth providers
password = login_submission.get("password") password = login_submission.get("password")
if login_type == LoginType.PASSWORD: if login_type == LoginType.PASSWORD:
if not self._password_enabled: if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.") raise SynapseError(400, "Password login has been disabled.")
@ -707,9 +710,10 @@ class AuthHandler(BaseHandler):
multiple inexact matches. multiple inexact matches.
Args: Args:
user_id (str): complete @user:id user_id (unicode): complete @user:id
password (unicode): the provided password
Returns: Returns:
(str) the canonical_user_id, or None if unknown user / bad password (unicode) the canonical_user_id, or None if unknown user / bad password
""" """
lookupres = yield self._find_user_id_and_pwd_hash(user_id) lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres: if not lookupres:
@ -728,15 +732,18 @@ class AuthHandler(BaseHandler):
device_id) device_id)
defer.returnValue(access_token) defer.returnValue(access_token)
@defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
yield self._check_mau_limits()
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
user_id = None
try: try:
macaroon = pymacaroons.Macaroon.deserialize(login_token) macaroon = pymacaroons.Macaroon.deserialize(login_token)
user_id = auth_api.get_user_id_from_macaroon(macaroon) user_id = auth_api.get_user_id_from_macaroon(macaroon)
auth_api.validate_macaroon(macaroon, "login", True, user_id) auth_api.validate_macaroon(macaroon, "login", True, user_id)
return user_id
except Exception: except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_access_token(self, access_token): def delete_access_token(self, access_token):
@ -849,14 +856,19 @@ class AuthHandler(BaseHandler):
"""Computes a secure hash of password. """Computes a secure hash of password.
Args: Args:
password (str): Password to hash. password (unicode): Password to hash.
Returns: Returns:
Deferred(str): Hashed password. Deferred(unicode): Hashed password.
""" """
def _do_hash(): def _do_hash():
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, # Normalise the Unicode in the password
bcrypt.gensalt(self.bcrypt_rounds)) pw = unicodedata.normalize("NFKC", password)
return bcrypt.hashpw(
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
bcrypt.gensalt(self.bcrypt_rounds),
).decode('ascii')
return make_deferred_yieldable( return make_deferred_yieldable(
threads.deferToThreadPool( threads.deferToThreadPool(
@ -868,16 +880,19 @@ class AuthHandler(BaseHandler):
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.
Args: Args:
password (str): Password to hash. password (unicode): Password to hash.
stored_hash (str): Expected hash value. stored_hash (unicode): Expected hash value.
Returns: Returns:
Deferred(bool): Whether self.hash(password) == stored_hash. Deferred(bool): Whether self.hash(password) == stored_hash.
""" """
def _do_validate_hash(): def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw( return bcrypt.checkpw(
password.encode('utf8') + self.hs.config.password_pepper, pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
stored_hash.encode('utf8') stored_hash.encode('utf8')
) )
@ -892,6 +907,19 @@ class AuthHandler(BaseHandler):
else: else:
return defer.succeed(False) return defer.succeed(False)
@defer.inlineCallbacks
def _check_mau_limits(self):
"""
Ensure that if mau blocking is enabled that invalid users cannot
log in.
"""
if self.hs.config.limit_usage_by_mau is True:
current_mau = yield self.store.count_monthly_users()
if current_mau >= self.hs.config.max_mau_value:
raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
)
@attr.s @attr.s
class MacaroonGenerator(object): class MacaroonGenerator(object):

View File

@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.replication_layer = hs.get_federation_client() self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
@ -255,7 +255,7 @@ class FederationHandler(BaseHandler):
# know about # know about
for p in prevs - seen: for p in prevs - seen:
state, got_auth_chain = ( state, got_auth_chain = (
yield self.replication_layer.get_state_for_room( yield self.federation_client.get_state_for_room(
origin, pdu.room_id, p origin, pdu.room_id, p
) )
) )
@ -338,7 +338,7 @@ class FederationHandler(BaseHandler):
# #
# see https://github.com/matrix-org/synapse/pull/1744 # see https://github.com/matrix-org/synapse/pull/1744
missing_events = yield self.replication_layer.get_missing_events( missing_events = yield self.federation_client.get_missing_events(
origin, origin,
pdu.room_id, pdu.room_id,
earliest_events_ids=list(latest), earliest_events_ids=list(latest),
@ -522,7 +522,7 @@ class FederationHandler(BaseHandler):
if dest == self.server_name: if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.") raise SynapseError(400, "Can't backfill from self.")
events = yield self.replication_layer.backfill( events = yield self.federation_client.backfill(
dest, dest,
room_id, room_id,
limit=limit, limit=limit,
@ -570,7 +570,7 @@ class FederationHandler(BaseHandler):
state_events = {} state_events = {}
events_to_state = {} events_to_state = {}
for e_id in edges: for e_id in edges:
state, auth = yield self.replication_layer.get_state_for_room( state, auth = yield self.federation_client.get_state_for_room(
destination=dest, destination=dest,
room_id=room_id, room_id=room_id,
event_id=e_id event_id=e_id
@ -612,7 +612,7 @@ class FederationHandler(BaseHandler):
results = yield logcontext.make_deferred_yieldable(defer.gatherResults( results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
logcontext.run_in_background( logcontext.run_in_background(
self.replication_layer.get_pdu, self.federation_client.get_pdu,
[dest], [dest],
event_id, event_id,
outlier=True, outlier=True,
@ -893,7 +893,7 @@ class FederationHandler(BaseHandler):
Invites must be signed by the invitee's server before distribution. Invites must be signed by the invitee's server before distribution.
""" """
pdu = yield self.replication_layer.send_invite( pdu = yield self.federation_client.send_invite(
destination=target_host, destination=target_host,
room_id=event.room_id, room_id=event.room_id,
event_id=event.event_id, event_id=event.event_id,
@ -955,7 +955,7 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin) target_hosts.insert(0, origin)
except ValueError: except ValueError:
pass pass
ret = yield self.replication_layer.send_join(target_hosts, event) ret = yield self.federation_client.send_join(target_hosts, event)
origin = ret["origin"] origin = ret["origin"]
state = ret["state"] state = ret["state"]
@ -1211,7 +1211,7 @@ class FederationHandler(BaseHandler):
except ValueError: except ValueError:
pass pass
yield self.replication_layer.send_leave( yield self.federation_client.send_leave(
target_hosts, target_hosts,
event event
) )
@ -1234,7 +1234,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
content={},): content={},):
origin, pdu = yield self.replication_layer.make_membership_event( origin, pdu = yield self.federation_client.make_membership_event(
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
@ -1567,7 +1567,7 @@ class FederationHandler(BaseHandler):
missing_auth_events.add(e_id) missing_auth_events.add(e_id)
for e_id in missing_auth_events: for e_id in missing_auth_events:
m_ev = yield self.replication_layer.get_pdu( m_ev = yield self.federation_client.get_pdu(
[origin], [origin],
e_id, e_id,
outlier=True, outlier=True,
@ -1777,7 +1777,7 @@ class FederationHandler(BaseHandler):
logger.info("Missing auth: %s", missing_auth) logger.info("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them. # If we don't have all the auth events, we need to get them.
try: try:
remote_auth_chain = yield self.replication_layer.get_event_auth( remote_auth_chain = yield self.federation_client.get_event_auth(
origin, event.room_id, event.event_id origin, event.room_id, event.event_id
) )
@ -1893,7 +1893,7 @@ class FederationHandler(BaseHandler):
try: try:
# 2. Get remote difference. # 2. Get remote difference.
result = yield self.replication_layer.query_auth( result = yield self.federation_client.query_auth(
origin, origin,
event.room_id, event.room_id,
event.event_id, event.event_id,
@ -2192,7 +2192,7 @@ class FederationHandler(BaseHandler):
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)
else: else:
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
yield self.replication_layer.forward_third_party_invite( yield self.federation_client.forward_third_party_invite(
destinations, destinations,
room_id, room_id,
event_dict, event_dict,

View File

@ -45,7 +45,7 @@ class RegistrationHandler(BaseHandler):
hs (synapse.server.HomeServer): hs (synapse.server.HomeServer):
""" """
super(RegistrationHandler, self).__init__(hs) super(RegistrationHandler, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler):
Args: Args:
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
one will be generated. one will be generated.
password (str) : The password to assign to this user so they can password (unicode) : The password to assign to this user so they can
login again. This can be None which means they cannot login again login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user). via a password (e.g. the user is an application service user).
generate_token (bool): Whether a new access token should be generate_token (bool): Whether a new access token should be
@ -144,6 +144,7 @@ class RegistrationHandler(BaseHandler):
Raises: Raises:
RegistrationError if there was a problem registering. RegistrationError if there was a problem registering.
""" """
yield self._check_mau_limits()
password_hash = None password_hash = None
if password: if password:
password_hash = yield self.auth_handler().hash(password) password_hash = yield self.auth_handler().hash(password)
@ -288,6 +289,7 @@ class RegistrationHandler(BaseHandler):
400, 400,
"User ID can only contain characters a-z, 0-9, or '=_-./'", "User ID can only contain characters a-z, 0-9, or '=_-./'",
) )
yield self._check_mau_limits()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
@ -437,7 +439,7 @@ class RegistrationHandler(BaseHandler):
""" """
if localpart is None: if localpart is None:
raise SynapseError(400, "Request must include user id") raise SynapseError(400, "Request must include user id")
yield self._check_mau_limits()
need_register = True need_register = True
try: try:
@ -531,3 +533,16 @@ class RegistrationHandler(BaseHandler):
remote_room_hosts=remote_room_hosts, remote_room_hosts=remote_room_hosts,
action="join", action="join",
) )
@defer.inlineCallbacks
def _check_mau_limits(self):
"""
Do not accept registrations if monthly active user limits exceeded
and limiting is enabled
"""
if self.hs.config.limit_usage_by_mau is True:
current_mau = yield self.store.count_monthly_users()
if current_mau >= self.hs.config.max_mau_value:
raise RegistrationError(
403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
)

View File

@ -13,12 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import cgi import cgi
import collections import collections
import logging import logging
import urllib
from six.moves import http_client from six import PY3
from six.moves import http_client, urllib
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@ -264,6 +265,7 @@ class JsonResource(HttpServer, resource.Resource):
self.hs = hs self.hs = hs
def register_paths(self, method, path_patterns, callback): def register_paths(self, method, path_patterns, callback):
method = method.encode("utf-8") # method is bytes on py3
for path_pattern in path_patterns: for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern) logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append( self.path_regexs.setdefault(method, []).append(
@ -296,8 +298,19 @@ class JsonResource(HttpServer, resource.Resource):
# here. If it throws an exception, that is handled by the wrapper # here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler. # installed by @request_handler.
def _unquote(s):
if PY3:
# On Python 3, unquote is unicode -> unicode
return urllib.parse.unquote(s)
else:
# On Python 2, unquote is bytes -> bytes We need to encode the
# URL again (as it was decoded by _get_handler_for request), as
# ASCII because it's a URL, and then decode it to get the UTF-8
# characters that were quoted.
return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
kwargs = intern_dict({ kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value name: _unquote(value) if value else value
for name, value in group_dict.items() for name, value in group_dict.items()
}) })
@ -313,9 +326,9 @@ class JsonResource(HttpServer, resource.Resource):
request (twisted.web.http.Request): request (twisted.web.http.Request):
Returns: Returns:
Tuple[Callable, dict[str, str]]: callback method, and the dict Tuple[Callable, dict[unicode, unicode]]: callback method, and the
mapping keys to path components as specified in the handler's dict mapping keys to path components as specified in the
path match regexp. handler's path match regexp.
The callback will normally be a method registered via The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either register_paths, so will return (possibly via Deferred) either
@ -327,7 +340,7 @@ class JsonResource(HttpServer, resource.Resource):
# Loop through all the registered callbacks to check if the method # Loop through all the registered callbacks to check if the method
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path) m = path_entry.pattern.match(request.path.decode('ascii'))
if m: if m:
# We found a match! # We found a match!
return path_entry.callback, m.groupdict() return path_entry.callback, m.groupdict()
@ -383,7 +396,7 @@ class RootRedirect(resource.Resource):
self.url = path self.url = path
def render_GET(self, request): def render_GET(self, request):
return redirectTo(self.url, request) return redirectTo(self.url.encode('ascii'), request)
def getChild(self, name, request): def getChild(self, name, request):
if len(name) == 0: if len(name) == 0:
@ -404,12 +417,14 @@ def respond_with_json(request, code, json_object, send_cors=False,
return return
if pretty_print: if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n" json_bytes = (encode_pretty_printed_json(json_object) + "\n"
).encode("utf-8")
else: else:
if canonical_json or synapse.events.USE_FROZEN_DICTS: if canonical_json or synapse.events.USE_FROZEN_DICTS:
# canonicaljson already encodes to bytes
json_bytes = encode_canonical_json(json_object) json_bytes = encode_canonical_json(json_object)
else: else:
json_bytes = json.dumps(json_object) json_bytes = json.dumps(json_object).encode("utf-8")
return respond_with_json_bytes( return respond_with_json_bytes(
request, code, json_bytes, request, code, json_bytes,

View File

@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False):
if not content_bytes and allow_empty_body: if not content_bytes and allow_empty_body:
return None return None
# Decode to Unicode so that simplejson will return Unicode strings on
# Python 2
try: try:
content = json.loads(content_bytes) content_unicode = content_bytes.decode('utf8')
except UnicodeDecodeError:
logger.warn("Unable to decode UTF-8")
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
try:
content = json.loads(content_unicode)
except Exception as e: except Exception as e:
logger.warn("Unable to parse JSON: %s", e) logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View File

@ -18,6 +18,7 @@ import hashlib
import hmac import hmac
import logging import logging
from six import text_type
from six.moves import http_client from six.moves import http_client
from twisted.internet import defer from twisted.internet import defer
@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet):
400, "username must be specified", errcode=Codes.BAD_JSON, 400, "username must be specified", errcode=Codes.BAD_JSON,
) )
else: else:
if (not isinstance(body['username'], str) or len(body['username']) > 512): if (
not isinstance(body['username'], text_type)
or len(body['username']) > 512
):
raise SynapseError(400, "Invalid username") raise SynapseError(400, "Invalid username")
username = body["username"].encode("utf-8") username = body["username"].encode("utf-8")
@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet):
400, "password must be specified", errcode=Codes.BAD_JSON, 400, "password must be specified", errcode=Codes.BAD_JSON,
) )
else: else:
if (not isinstance(body['password'], str) or len(body['password']) > 512): if (
not isinstance(body['password'], text_type)
or len(body['password']) > 512
):
raise SynapseError(400, "Invalid password") raise SynapseError(400, "Invalid password")
password = body["password"].encode("utf-8") password = body["password"].encode("utf-8")
@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet):
want_mac.update(b"admin" if admin else b"notadmin") want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest() want_mac = want_mac.hexdigest()
if not hmac.compare_digest(want_mac, got_mac): if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
raise SynapseError( raise SynapseError(403, "HMAC incorrect")
403, "HMAC incorrect",
)
# Reuse the parts of RegisterRestServlet to reduce code duplication # Reuse the parts of RegisterRestServlet to reduce code duplication
from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet
register = RegisterRestServlet(self.hs) register = RegisterRestServlet(self.hs)
(user_id, _) = yield register.registration_handler.register( (user_id, _) = yield register.registration_handler.register(
localpart=username.lower(), password=password, admin=bool(admin), localpart=body['username'].lower(),
password=body["password"],
admin=bool(admin),
generate_token=False, generate_token=False,
) )

View File

@ -18,7 +18,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.types import RoomAlias from synapse.types import RoomAlias
@ -159,7 +159,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
room = yield self.store.get_room(room_id) room = yield self.store.get_room(room_id)
if room is None: if room is None:
raise SynapseError(400, "Unknown room") raise NotFoundError("Unknown room")
defer.returnValue((200, { defer.returnValue((200, {
"visibility": "public" if room["is_public"] else "private" "visibility": "public" if room["is_public"] else "private"

View File

@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
kind = "user" kind = b"user"
if "kind" in request.args: if b"kind" in request.args:
kind = request.args["kind"][0] kind = request.args[b"kind"][0]
if kind == "guest": if kind == b"guest":
ret = yield self._do_guest_registration(body) ret = yield self._do_guest_registration(body)
defer.returnValue(ret) defer.returnValue(ret)
return return
elif kind != "user": elif kind != b"user":
raise UnrecognizedRequestError( raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,) "Do not understand membership kind: %s" % (kind,)
) )
@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet):
assert_params_in_dict(params, ["password"]) assert_params_in_dict(params, ["password"])
desired_username = params.get("username", None) desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None) guest_access_token = params.get("guest_access_token", None)
new_password = params.get("password", None)
if desired_username is not None: if desired_username is not None:
desired_username = desired_username.lower() desired_username = desired_username.lower()

View File

@ -177,7 +177,7 @@ class MediaStorage(object):
if res: if res:
with res: with res:
consumer = BackgroundFileConsumer( consumer = BackgroundFileConsumer(
open(local_path, "w"), self.hs.get_reactor()) open(local_path, "wb"), self.hs.get_reactor())
yield res.write_to_consumer(consumer) yield res.write_to_consumer(consumer)
yield consumer.wait() yield consumer.wait()
defer.returnValue(local_path) defer.returnValue(local_path)

View File

@ -577,7 +577,7 @@ def _make_state_cache_entry(
def _ordered_events(events): def _ordered_events(events):
def key_func(e): def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest() return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
return sorted(events, key=key_func) return sorted(events, key=key_func)

View File

@ -66,6 +66,7 @@ class DataStore(RoomMemberStore, RoomStore,
PresenceStore, TransactionStore, PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore, DirectoryStore, KeyStore, StateStore, SignatureStore,
ApplicationServiceStore, ApplicationServiceStore,
EventsStore,
EventFederationStore, EventFederationStore,
MediaRepositoryStore, MediaRepositoryStore,
RejectionsStore, RejectionsStore,
@ -73,7 +74,6 @@ class DataStore(RoomMemberStore, RoomStore,
PusherStore, PusherStore,
PushRuleStore, PushRuleStore,
ApplicationServiceTransactionStore, ApplicationServiceTransactionStore,
EventsStore,
ReceiptsStore, ReceiptsStore,
EndToEndKeyStore, EndToEndKeyStore,
SearchStore, SearchStore,
@ -94,6 +94,7 @@ class DataStore(RoomMemberStore, RoomStore,
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self.db_conn = db_conn
self._stream_id_gen = StreamIdGenerator( self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")] extra_tables=[("local_invites", "stream_id")]
@ -266,6 +267,31 @@ class DataStore(RoomMemberStore, RoomStore,
return self.runInteraction("count_users", _count_users) return self.runInteraction("count_users", _count_users)
def count_monthly_users(self):
"""Counts the number of users who used this homeserver in the last 30 days
This method should be refactored with count_daily_users - the only
reason not to is waiting on definition of mau
Returns:
Defered[int]
"""
def _count_monthly_users(txn):
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT user_id FROM user_ips
WHERE last_seen > ?
GROUP BY user_id
) u
"""
txn.execute(sql, (thirty_days_ago,))
count, = txn.fetchone()
return count
return self.runInteraction("count_monthly_users", _count_monthly_users)
def count_r30_users(self): def count_r30_users(self):
""" """
Counts the number of 30 day retained users, defined as:- Counts the number of 30 day retained users, defined as:-

View File

@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.storage.events import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from ._base import SQLBaseStore from ._base import SQLBaseStore

View File

@ -25,7 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.events import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.signatures import SignatureWorkerStore from synapse.storage.signatures import SignatureWorkerStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached

View File

@ -34,6 +34,8 @@ from synapse.api.errors import SynapseError
from synapse.events import EventBase # noqa: F401 from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
@ -65,7 +67,13 @@ state_delta_reuse_delta_counter = Counter(
def encode_json(json_object): def encode_json(json_object):
return frozendict_json_encoder.encode(json_object) """
Encode a Python object as JSON and return it in a Unicode string.
"""
out = frozendict_json_encoder.encode(json_object)
if isinstance(out, bytes):
out = out.decode('utf8')
return out
class _EventPeristenceQueue(object): class _EventPeristenceQueue(object):
@ -193,7 +201,9 @@ def _retry_on_integrity_error(func):
return f return f
class EventsStore(EventsWorkerStore): # inherits from EventFederationStore so that we can call _update_backward_extremities
# and _handle_mult_prev_events (though arguably those could both be moved in here)
class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
@ -1054,7 +1064,7 @@ class EventsStore(EventsWorkerStore):
metadata_json = encode_json( metadata_json = encode_json(
event.internal_metadata.get_dict() event.internal_metadata.get_dict()
).decode("UTF-8") )
sql = ( sql = (
"UPDATE event_json SET internal_metadata = ?" "UPDATE event_json SET internal_metadata = ?"
@ -1168,8 +1178,8 @@ class EventsStore(EventsWorkerStore):
"room_id": event.room_id, "room_id": event.room_id,
"internal_metadata": encode_json( "internal_metadata": encode_json(
event.internal_metadata.get_dict() event.internal_metadata.get_dict()
).decode("UTF-8"), ),
"json": encode_json(event_dict(event)).decode("UTF-8"), "json": encode_json(event_dict(event)),
} }
for event, _ in events_and_contexts for event, _ in events_and_contexts
], ],

View File

@ -24,7 +24,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.storage.events import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.caches import intern_string from synapse.util.caches import intern_string

View File

@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
txn (cursor): txn (cursor):
event_id (str): Id for the Event. event_id (str): Id for the Event.
Returns: Returns:
A dict of algorithm -> hash. A dict[unicode, bytes] of algorithm -> hash.
""" """
query = ( query = (
"SELECT algorithm, hash" "SELECT algorithm, hash"

View File

@ -43,7 +43,7 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.events import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background

View File

@ -137,7 +137,7 @@ class DomainSpecificString(
@classmethod @classmethod
def from_string(cls, s): def from_string(cls, s):
"""Parse the string given by 's' into a structure object.""" """Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0] != cls.SIGIL: if len(s) < 1 or s[0:1] != cls.SIGIL:
raise SynapseError(400, "Expected %s string to start with '%s'" % ( raise SynapseError(400, "Expected %s string to start with '%s'" % (
cls.__name__, cls.SIGIL, cls.__name__, cls.SIGIL,
)) ))

View File

@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate() # If we're passed a cache_context then we'll want to call its
# whenever we are invalidated # invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name] list_args = arg_dict[self.list_name]
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
results = {} results = {}
cached_defers = {}
missing = [] def update_results_dict(res, arg):
results[arg] = res
# list of deferreds to wait for
cached_defers = []
missing = set()
# If the cache takes a single arg then that is used as the key, # If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used. # otherwise a tuple is used.
if num_args == 1: if num_args == 1:
def cache_get(arg): def arg_to_cache_key(arg):
return cache.get(arg, callback=invalidate_callback) return arg
else: else:
key = list(keyargs) keylist = list(keyargs)
def cache_get(arg): def arg_to_cache_key(arg):
key[self.list_pos] = arg keylist[self.list_pos] = arg
return cache.get(tuple(key), callback=invalidate_callback) return tuple(keylist)
for arg in list_args: for arg in list_args:
try: try:
res = cache_get(arg) res = cache.get(arg_to_cache_key(arg),
callback=invalidate_callback)
if not isinstance(res, ObservableDeferred): if not isinstance(res, ObservableDeferred):
results[arg] = res results[arg] = res
elif not res.has_succeeded(): elif not res.has_succeeded():
res = res.observe() res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg) res.addCallback(update_results_dict, arg)
cached_defers[arg] = res cached_defers.append(res)
else: else:
results[arg] = res.get_result() results[arg] = res.get_result()
except KeyError: except KeyError:
missing.append(arg) missing.add(arg)
if missing: if missing:
args_to_call = dict(arg_dict) # we need an observable deferred for each entry in the list,
args_to_call[self.list_name] = missing # which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
for arg in missing:
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred)
cache.set(key, observable, callback=invalidate_callback)
ret_d = defer.maybeDeferred( def complete_all(res):
# the wrapped function has completed. It returns a
# a dict. We can now resolve the observable deferreds in
# the cache and update our own result map.
for e in missing:
val = res.get(e, None)
deferreds_map[e].callback(val)
results[e] = val
def errback(f):
# the wrapped function has failed. Invalidate any cache
# entries we're supposed to be populating, and fail
# their deferreds.
for e in missing:
key = arg_to_cache_key(e)
cache.invalidate(key)
deferreds_map[e].errback(f)
# return the failure, to propagate to our caller.
return f
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = list(missing)
cached_defers.append(defer.maybeDeferred(
logcontext.preserve_fn(self.function_to_call), logcontext.preserve_fn(self.function_to_call),
**args_to_call **args_to_call
) ).addCallbacks(complete_all, errback))
ret_d = ObservableDeferred(ret_d)
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
if num_args == 1:
cache.set(
arg, observer,
callback=invalidate_callback
)
def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, arg)
else:
key = list(keyargs)
key[self.list_pos] = arg
cache.set(
tuple(key), observer,
callback=invalidate_callback
)
def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached_defers[arg] = res
if cached_defers: if cached_defers:
def update_results_dict(res): d = defer.gatherResults(
results.update(res) cached_defers,
return results
return logcontext.make_deferred_yieldable(defer.gatherResults(
list(cached_defers.values()),
consumeErrors=True, consumeErrors=True,
).addCallback(update_results_dict).addErrback( ).addCallbacks(
lambda _: results,
unwrapFirstError unwrapFirstError
)) )
return logcontext.make_deferred_yieldable(d)
else: else:
return results return results
@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
cache. cache.
Args: Args:
cache (Cache): The underlying cache to use. cached_method_name (str): The name of the single-item lookup method.
This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache. do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache num_args (int): Number of arguments to use as the key in the cache

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from six import string_types from six import binary_type, text_type
from canonicaljson import json from canonicaljson import json
from frozendict import frozendict from frozendict import frozendict
@ -26,7 +26,7 @@ def freeze(o):
if isinstance(o, frozendict): if isinstance(o, frozendict):
return o return o
if isinstance(o, string_types): if isinstance(o, (binary_type, text_type)):
return o return o
try: try:
@ -41,7 +41,7 @@ def unfreeze(o):
if isinstance(o, (dict, frozendict)): if isinstance(o, (dict, frozendict)):
return dict({k: unfreeze(v) for k, v in o.items()}) return dict({k: unfreeze(v) for k, v in o.items()})
if isinstance(o, string_types): if isinstance(o, (binary_type, text_type)):
return o return o
try: try:

View File

@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase):
self.auth = Auth(self.hs) self.auth = Auth(self.hs)
self.test_user = "@foo:bar" self.test_user = "@foo:bar"
self.test_token = "_test_token_" self.test_token = b"_test_token_"
# this is overridden for the appservice tests # this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10" request.getClientIP.return_value = "192.168.10.10"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42" request.getClientIP.return_value = "131.111.8.42"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self): def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=None, ip_range_whitelist=None,
@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), masquerading_user_id) self.assertEquals(
requester.user.to_string(),
masquerading_user_id.decode('utf8')
)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self): def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=None, ip_range_whitelist=None,
@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase):
# check the token works # check the token works
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [token] request.args[b"access_token"] = [token.encode('ascii')]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
self.assertEqual(UserID.from_string(USER_ID), requester.user) self.assertEqual(UserID.from_string(USER_ID), requester.user)
@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase):
# the token should *not* work now # the token should *not* work now
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [guest_tok] request.args[b"access_token"] = [guest_tok.encode('ascii')]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(AuthError) as cm: with self.assertRaises(AuthError) as cm:

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mock import Mock
import pymacaroons import pymacaroons
@ -19,6 +20,7 @@ from twisted.internet import defer
import synapse import synapse
import synapse.api.errors import synapse.api.errors
from synapse.api.errors import AuthError
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from tests import unittest from tests import unittest
@ -37,6 +39,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = AuthHandlers(self.hs) self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator() self.macaroon_generator = self.hs.get_macaroon_generator()
# MAU tests
self.hs.config.max_mau_value = 50
self.small_number_of_users = 1
self.large_number_of_users = 100
def test_token_is_a_macaroon(self): def test_token_is_a_macaroon(self):
token = self.macaroon_generator.generate_access_token("some_user") token = self.macaroon_generator.generate_access_token("some_user")
@ -71,38 +77,37 @@ class AuthTestCase(unittest.TestCase):
v.satisfy_general(verify_nonce) v.satisfy_general(verify_nonce)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
@defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000 self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000 "a_user", 5000
) )
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.assertEqual( token
"a_user",
self.auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
) )
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected # when we advance the clock, the token should be rejected
self.hs.clock.now = 6000 self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError): with self.assertRaises(synapse.api.errors.AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
token token
) )
@defer.inlineCallbacks
def test_short_term_login_token_cannot_replace_user_id(self): def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000 "a_user", 5000
) )
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
self.assertEqual( self.assertEqual(
"a_user", "a_user", user_id
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
) )
# add another "user_id" caveat, which might allow us to override the # add another "user_id" caveat, which might allow us to override the
@ -110,6 +115,57 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = b_user") macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError): with self.assertRaises(synapse.api.errors.AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize() macaroon.serialize()
) )
@defer.inlineCallbacks
def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
yield self.auth_handler.get_access_token_for_user_id('user_a')
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
@defer.inlineCallbacks
def test_mau_limits_exceeded(self):
self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(AuthError):
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().count_monthly_users = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(AuthError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
@defer.inlineCallbacks
def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
# Ensure does not raise exception
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().count_monthly_users = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
def _get_macaroon(self):
token = self.macaroon_generator.generate_short_term_login_token(
"user_a", 5000
)
return pymacaroons.Macaroon.deserialize(token)

View File

@ -17,6 +17,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import RegistrationError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -77,3 +78,53 @@ class RegistrationTestCase(unittest.TestCase):
requester, local_part, display_name) requester, local_part, display_name)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@defer.inlineCallbacks
def test_cannot_register_when_mau_limits_exceeded(self):
local_part = "someone"
display_name = "someone"
requester = create_requester("@as:test")
store = self.hs.get_datastore()
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
lots_of_users = 100
small_number_users = 1
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
# Ensure does not throw exception
yield self.handler.get_or_create_user(requester, 'a', display_name)
self.hs.config.limit_usage_by_mau = True
with self.assertRaises(RegistrationError):
yield self.handler.get_or_create_user(requester, 'b', display_name)
store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
self._macaroon_mock_generator("another_secret")
# Ensure does not throw exception
yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
with self.assertRaises(RegistrationError):
yield self.handler.register(localpart=local_part)
self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
with self.assertRaises(RegistrationError):
yield self.handler.register_saml2(local_part)
def _macaroon_mock_generator(self, secret):
"""
Reset macaroon generator in the case where the test creates multiple users
"""
macaroon_generator = Mock(
generate_access_token=Mock(return_value=secret))
self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler

View File

@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"):
"content": content, "content": content,
} }
], ],
"pdu_failures": [],
} }

View File

@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.utils
class InitTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(InitTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
hs.config.max_mau_value = 50
hs.config.limit_usage_by_mau = True
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def test_count_monthly_users(self):
count = yield self.store.count_monthly_users()
self.assertEqual(0, count)
yield self._insert_user_ips("@user:server1")
yield self._insert_user_ips("@user:server2")
count = yield self.store.count_monthly_users()
self.assertEqual(2, count)
@defer.inlineCallbacks
def _insert_user_ips(self, user):
"""
Helper function to populate user_ips without using batch insertion infra
args:
user (str): specify username i.e. @user:server.com
"""
yield self.store._simple_upsert(
table="user_ips",
keyvalues={
"user_id": user,
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"device_id": "device_id",
},
values={
"last_seen": self.clock.time_msec(),
}
)

View File

@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
r = yield obj.fn(2, 3) r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips') self.assertEqual(r, 'chips')
obj.mock.assert_not_called() obj.mock.assert_not_called()
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached()
def fn(self, arg1, arg2):
pass
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
assert (
logcontext.LoggingContext.current_context().request == "c1"
)
# we want this to behave like an asynchronous function
yield run_on_reactor()
assert (
logcontext.LoggingContext.current_context().request == "c1"
)
defer.returnValue(self.mock(args1, arg2))
with logcontext.LoggingContext() as c1:
c1.request = "c1"
obj = Cls()
obj.mock.return_value = {10: 'fish', 20: 'chips'}
d1 = obj.list_fn([10, 20], 2)
self.assertEqual(
logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel,
)
r = yield d1
self.assertEqual(
logcontext.LoggingContext.current_context(),
c1
)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r, {10: 'fish', 20: 'chips'})
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = {30: 'peas'}
r = yield obj.list_fn([20, 30], 2)
obj.mock.assert_called_once_with([30], 2)
self.assertEqual(r, {20: 'chips', 30: 'peas'})
obj.mock.reset_mock()
# all the values should now be cached
r = yield obj.fn(10, 2)
self.assertEqual(r, 'fish')
r = yield obj.fn(20, 2)
self.assertEqual(r, 'chips')
r = yield obj.fn(30, 2)
self.assertEqual(r, 'peas')
r = yield obj.list_fn([10, 20, 30], 2)
obj.mock.assert_not_called()
self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
@defer.inlineCallbacks
def test_invalidate(self):
"""Make sure that invalidation callbacks are called."""
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached()
def fn(self, arg1, arg2):
pass
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
yield run_on_reactor()
defer.returnValue(self.mock(args1, arg2))
obj = Cls()
invalidate0 = mock.Mock()
invalidate1 = mock.Mock()
# cache miss
obj.mock.return_value = {10: 'fish', 20: 'chips'}
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r1, {10: 'fish', 20: 'chips'})
obj.mock.reset_mock()
# cache hit
r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
obj.mock.assert_not_called()
self.assertEqual(r2, {10: 'fish', 20: 'chips'})
invalidate0.assert_not_called()
invalidate1.assert_not_called()
# now if we invalidate the keys, both invalidations should get called
obj.fn.invalidate((10, 2))
invalidate0.assert_called_once()
invalidate1.assert_called_once()

View File

@ -193,7 +193,7 @@ class MockHttpResource(HttpServer):
self.prefix = prefix self.prefix = prefix
def trigger_get(self, path): def trigger_get(self, path):
return self.trigger("GET", path, None) return self.trigger(b"GET", path, None)
@patch('twisted.web.http.Request') @patch('twisted.web.http.Request')
@defer.inlineCallbacks @defer.inlineCallbacks
@ -227,7 +227,7 @@ class MockHttpResource(HttpServer):
headers = {} headers = {}
if federation_auth: if federation_auth:
headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="] headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it # return the right path if the event requires it
@ -241,6 +241,9 @@ class MockHttpResource(HttpServer):
except Exception: except Exception:
pass pass
if isinstance(path, bytes):
path = path.decode('utf8')
for (method, pattern, func) in self.callbacks: for (method, pattern, func) in self.callbacks:
if http_method != method: if http_method != method:
continue continue
@ -249,7 +252,7 @@ class MockHttpResource(HttpServer):
if matcher: if matcher:
try: try:
args = [ args = [
urlparse.unquote(u).decode("UTF-8") urlparse.unquote(u)
for u in matcher.groups() for u in matcher.groups()
] ]