mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-11 06:44:19 -05:00
Merge branch 'develop' of github.com:matrix-org/synapse into neilj/mau_tracker
This commit is contained in:
commit
d766f26de9
@ -63,3 +63,6 @@ Christoph Witzany <christoph at web.crofting.com>
|
|||||||
|
|
||||||
Pierre Jaury <pierre at jaury.eu>
|
Pierre Jaury <pierre at jaury.eu>
|
||||||
* Docker packaging
|
* Docker packaging
|
||||||
|
|
||||||
|
Serban Constantin <serban.constantin at gmail dot com>
|
||||||
|
* Small bug fix
|
22
Dockerfile
22
Dockerfile
@ -1,16 +1,32 @@
|
|||||||
FROM docker.io/python:2-alpine3.7
|
FROM docker.io/python:2-alpine3.7
|
||||||
|
|
||||||
RUN apk add --no-cache --virtual .nacl_deps su-exec build-base libffi-dev zlib-dev libressl-dev libjpeg-turbo-dev linux-headers postgresql-dev libxslt-dev
|
RUN apk add --no-cache --virtual .nacl_deps \
|
||||||
|
build-base \
|
||||||
|
libffi-dev \
|
||||||
|
libjpeg-turbo-dev \
|
||||||
|
libressl-dev \
|
||||||
|
libxslt-dev \
|
||||||
|
linux-headers \
|
||||||
|
postgresql-dev \
|
||||||
|
su-exec \
|
||||||
|
zlib-dev
|
||||||
|
|
||||||
COPY . /synapse
|
COPY . /synapse
|
||||||
|
|
||||||
# A wheel cache may be provided in ./cache for faster build
|
# A wheel cache may be provided in ./cache for faster build
|
||||||
RUN cd /synapse \
|
RUN cd /synapse \
|
||||||
&& pip install --upgrade pip setuptools psycopg2 lxml \
|
&& pip install --upgrade \
|
||||||
|
lxml \
|
||||||
|
pip \
|
||||||
|
psycopg2 \
|
||||||
|
setuptools \
|
||||||
&& mkdir -p /synapse/cache \
|
&& mkdir -p /synapse/cache \
|
||||||
&& pip install -f /synapse/cache --upgrade --process-dependency-links . \
|
&& pip install -f /synapse/cache --upgrade --process-dependency-links . \
|
||||||
&& mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \
|
&& mv /synapse/contrib/docker/start.py /synapse/contrib/docker/conf / \
|
||||||
&& rm -rf setup.py setup.cfg synapse
|
&& rm -rf \
|
||||||
|
setup.cfg \
|
||||||
|
setup.py \
|
||||||
|
synapse
|
||||||
|
|
||||||
VOLUME ["/data"]
|
VOLUME ["/data"]
|
||||||
|
|
||||||
|
1
changelog.d/2952.bugfix
Normal file
1
changelog.d/2952.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Make /directory/list API return 404 for room not found instead of 400
|
1
changelog.d/3384.misc
Normal file
1
changelog.d/3384.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Rewrite cache list decorator
|
1
changelog.d/3543.misc
Normal file
1
changelog.d/3543.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Improve Dockerfile and docker-compose instructions
|
1
changelog.d/3569.bugfix
Normal file
1
changelog.d/3569.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Unicode passwords are now normalised before hashing, preventing the instance where two different devices or browsers might send a different UTF-8 sequence for the password.
|
1
changelog.d/3612.misc
Normal file
1
changelog.d/3612.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Make EventStore inherit from EventFederationStore
|
1
changelog.d/3628.misc
Normal file
1
changelog.d/3628.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Remove unused field "pdu_failures" from transactions.
|
1
changelog.d/3630.feature
Normal file
1
changelog.d/3630.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add ability to limit number of monthly active users on the server
|
1
changelog.d/3634.misc
Normal file
1
changelog.d/3634.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
rename replication_layer to federation_client
|
@ -9,13 +9,7 @@ use that server.
|
|||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
Build the docker image with the `docker build` command from the root of the synapse repository.
|
Build the docker image with the `docker-compose build` command.
|
||||||
|
|
||||||
```
|
|
||||||
docker build -t docker.io/matrixdotorg/synapse .
|
|
||||||
```
|
|
||||||
|
|
||||||
The `-t` option sets the image tag. Official images are tagged `matrixdotorg/synapse:<version>` where `<version>` is the same as the release tag in the synapse git repository.
|
|
||||||
|
|
||||||
You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project.
|
You may have a local Python wheel cache available, in which case copy the relevant packages in the ``cache/`` directory at the root of the project.
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ version: '3'
|
|||||||
services:
|
services:
|
||||||
|
|
||||||
synapse:
|
synapse:
|
||||||
|
build: ../..
|
||||||
image: docker.io/matrixdotorg/synapse:latest
|
image: docker.io/matrixdotorg/synapse:latest
|
||||||
# Since snyapse does not retry to connect to the database, restart upon
|
# Since snyapse does not retry to connect to the database, restart upon
|
||||||
# failure
|
# failure
|
||||||
|
@ -252,10 +252,10 @@ class Auth(object):
|
|||||||
if ip_address not in app_service.ip_range_whitelist:
|
if ip_address not in app_service.ip_range_whitelist:
|
||||||
defer.returnValue((None, None))
|
defer.returnValue((None, None))
|
||||||
|
|
||||||
if "user_id" not in request.args:
|
if b"user_id" not in request.args:
|
||||||
defer.returnValue((app_service.sender, app_service))
|
defer.returnValue((app_service.sender, app_service))
|
||||||
|
|
||||||
user_id = request.args["user_id"][0]
|
user_id = request.args[b"user_id"][0].decode('utf8')
|
||||||
if app_service.sender == user_id:
|
if app_service.sender == user_id:
|
||||||
defer.returnValue((app_service.sender, app_service))
|
defer.returnValue((app_service.sender, app_service))
|
||||||
|
|
||||||
|
@ -55,6 +55,7 @@ class Codes(object):
|
|||||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||||
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
|
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
|
||||||
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
|
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
|
||||||
|
MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
|
||||||
|
|
||||||
|
|
||||||
class CodeMessageException(RuntimeError):
|
class CodeMessageException(RuntimeError):
|
||||||
|
@ -20,6 +20,8 @@ import sys
|
|||||||
|
|
||||||
from six import iteritems
|
from six import iteritems
|
||||||
|
|
||||||
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
from twisted.application import service
|
from twisted.application import service
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.web.resource import EncodingResourceWrapper, NoResource
|
from twisted.web.resource import EncodingResourceWrapper, NoResource
|
||||||
@ -301,6 +303,11 @@ class SynapseHomeServer(HomeServer):
|
|||||||
quit_with_error(e.message)
|
quit_with_error(e.message)
|
||||||
|
|
||||||
|
|
||||||
|
# Gauges to expose monthly active user control metrics
|
||||||
|
current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
|
||||||
|
max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
|
||||||
|
|
||||||
|
|
||||||
def setup(config_options):
|
def setup(config_options):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -516,6 +523,18 @@ def run(hs):
|
|||||||
MonthlyActiveUsersStore(hs).reap_monthly_active_users, 1000 * 60 * 60
|
MonthlyActiveUsersStore(hs).reap_monthly_active_users, 1000 * 60 * 60
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def generate_monthly_active_users():
|
||||||
|
count = 0
|
||||||
|
if hs.config.limit_usage_by_mau:
|
||||||
|
count = yield hs.get_datastore().count_monthly_users()
|
||||||
|
current_mau_gauge.set(float(count))
|
||||||
|
max_mau_value_gauge.set(float(hs.config.max_mau_value))
|
||||||
|
|
||||||
|
generate_monthly_active_users()
|
||||||
|
if hs.config.limit_usage_by_mau:
|
||||||
|
clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
|
||||||
|
|
||||||
if hs.config.report_stats:
|
if hs.config.report_stats:
|
||||||
logger.info("Scheduling stats reporting for 3 hour intervals")
|
logger.info("Scheduling stats reporting for 3 hour intervals")
|
||||||
clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
|
clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
|
||||||
|
@ -67,6 +67,14 @@ class ServerConfig(Config):
|
|||||||
"block_non_admin_invites", False,
|
"block_non_admin_invites", False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Options to control access by tracking MAU
|
||||||
|
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
|
||||||
|
if self.limit_usage_by_mau:
|
||||||
|
self.max_mau_value = config.get(
|
||||||
|
"max_mau_value", 0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.max_mau_value = 0
|
||||||
# FIXME: federation_domain_whitelist needs sytests
|
# FIXME: federation_domain_whitelist needs sytests
|
||||||
self.federation_domain_whitelist = None
|
self.federation_domain_whitelist = None
|
||||||
federation_domain_whitelist = config.get(
|
federation_domain_whitelist = config.get(
|
||||||
|
@ -207,10 +207,6 @@ class FederationServer(FederationBase):
|
|||||||
edu.content
|
edu.content
|
||||||
)
|
)
|
||||||
|
|
||||||
pdu_failures = getattr(transaction, "pdu_failures", [])
|
|
||||||
for fail in pdu_failures:
|
|
||||||
logger.info("Got failure %r", fail)
|
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"pdus": pdu_results,
|
"pdus": pdu_results,
|
||||||
}
|
}
|
||||||
|
@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object):
|
|||||||
|
|
||||||
self.edus = SortedDict() # stream position -> Edu
|
self.edus = SortedDict() # stream position -> Edu
|
||||||
|
|
||||||
self.failures = SortedDict() # stream position -> (destination, Failure)
|
|
||||||
|
|
||||||
self.device_messages = SortedDict() # stream position -> destination
|
self.device_messages = SortedDict() # stream position -> destination
|
||||||
|
|
||||||
self.pos = 1
|
self.pos = 1
|
||||||
@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object):
|
|||||||
|
|
||||||
for queue_name in [
|
for queue_name in [
|
||||||
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
|
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
|
||||||
"edus", "failures", "device_messages", "pos_time",
|
"edus", "device_messages", "pos_time",
|
||||||
]:
|
]:
|
||||||
register(queue_name, getattr(self, queue_name))
|
register(queue_name, getattr(self, queue_name))
|
||||||
|
|
||||||
@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object):
|
|||||||
for key in keys[:i]:
|
for key in keys[:i]:
|
||||||
del self.edus[key]
|
del self.edus[key]
|
||||||
|
|
||||||
# Delete things out of failure map
|
|
||||||
keys = self.failures.keys()
|
|
||||||
i = self.failures.bisect_left(position_to_delete)
|
|
||||||
for key in keys[:i]:
|
|
||||||
del self.failures[key]
|
|
||||||
|
|
||||||
# Delete things out of device map
|
# Delete things out of device map
|
||||||
keys = self.device_messages.keys()
|
keys = self.device_messages.keys()
|
||||||
i = self.device_messages.bisect_left(position_to_delete)
|
i = self.device_messages.bisect_left(position_to_delete)
|
||||||
@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object):
|
|||||||
|
|
||||||
self.notifier.on_new_replication_data()
|
self.notifier.on_new_replication_data()
|
||||||
|
|
||||||
def send_failure(self, failure, destination):
|
|
||||||
"""As per TransactionQueue"""
|
|
||||||
pos = self._next_pos()
|
|
||||||
|
|
||||||
self.failures[pos] = (destination, str(failure))
|
|
||||||
self.notifier.on_new_replication_data()
|
|
||||||
|
|
||||||
def send_device_messages(self, destination):
|
def send_device_messages(self, destination):
|
||||||
"""As per TransactionQueue"""
|
"""As per TransactionQueue"""
|
||||||
pos = self._next_pos()
|
pos = self._next_pos()
|
||||||
@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object):
|
|||||||
for (pos, edu) in edus:
|
for (pos, edu) in edus:
|
||||||
rows.append((pos, EduRow(edu)))
|
rows.append((pos, EduRow(edu)))
|
||||||
|
|
||||||
# Fetch changed failures
|
|
||||||
i = self.failures.bisect_right(from_token)
|
|
||||||
j = self.failures.bisect_right(to_token) + 1
|
|
||||||
failures = self.failures.items()[i:j]
|
|
||||||
|
|
||||||
for (pos, (destination, failure)) in failures:
|
|
||||||
rows.append((pos, FailureRow(
|
|
||||||
destination=destination,
|
|
||||||
failure=failure,
|
|
||||||
)))
|
|
||||||
|
|
||||||
# Fetch changed device messages
|
# Fetch changed device messages
|
||||||
i = self.device_messages.bisect_right(from_token)
|
i = self.device_messages.bisect_right(from_token)
|
||||||
j = self.device_messages.bisect_right(to_token) + 1
|
j = self.device_messages.bisect_right(to_token) + 1
|
||||||
@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
|
|||||||
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
|
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
|
||||||
|
|
||||||
|
|
||||||
class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
|
|
||||||
"destination", # str
|
|
||||||
"failure",
|
|
||||||
))):
|
|
||||||
"""Streams failures to a remote server. Failures are issued when there was
|
|
||||||
something wrong with a transaction the remote sent us, e.g. it included
|
|
||||||
an event that was invalid.
|
|
||||||
"""
|
|
||||||
|
|
||||||
TypeId = "f"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_data(data):
|
|
||||||
return FailureRow(
|
|
||||||
destination=data["destination"],
|
|
||||||
failure=data["failure"],
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_data(self):
|
|
||||||
return {
|
|
||||||
"destination": self.destination,
|
|
||||||
"failure": self.failure,
|
|
||||||
}
|
|
||||||
|
|
||||||
def add_to_buffer(self, buff):
|
|
||||||
buff.failures.setdefault(self.destination, []).append(self.failure)
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
|
class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
|
||||||
"destination", # str
|
"destination", # str
|
||||||
))):
|
))):
|
||||||
@ -471,7 +417,6 @@ TypeToRow = {
|
|||||||
PresenceRow,
|
PresenceRow,
|
||||||
KeyedEduRow,
|
KeyedEduRow,
|
||||||
EduRow,
|
EduRow,
|
||||||
FailureRow,
|
|
||||||
DeviceRow,
|
DeviceRow,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
|
|||||||
"presence", # list(UserPresenceState)
|
"presence", # list(UserPresenceState)
|
||||||
"keyed_edus", # dict of destination -> { key -> Edu }
|
"keyed_edus", # dict of destination -> { key -> Edu }
|
||||||
"edus", # dict of destination -> [Edu]
|
"edus", # dict of destination -> [Edu]
|
||||||
"failures", # dict of destination -> [failures]
|
|
||||||
"device_destinations", # set of destinations
|
"device_destinations", # set of destinations
|
||||||
))
|
))
|
||||||
|
|
||||||
@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows):
|
|||||||
presence=[],
|
presence=[],
|
||||||
keyed_edus={},
|
keyed_edus={},
|
||||||
edus={},
|
edus={},
|
||||||
failures={},
|
|
||||||
device_destinations=set(),
|
device_destinations=set(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows):
|
|||||||
edu.destination, edu.edu_type, edu.content, key=None,
|
edu.destination, edu.edu_type, edu.content, key=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
for destination, failure_list in iteritems(buff.failures):
|
|
||||||
for failure in failure_list:
|
|
||||||
transaction_queue.send_failure(destination, failure)
|
|
||||||
|
|
||||||
for destination in buff.device_destinations:
|
for destination in buff.device_destinations:
|
||||||
transaction_queue.send_device_messages(destination)
|
transaction_queue.send_device_messages(destination)
|
||||||
|
@ -116,9 +116,6 @@ class TransactionQueue(object):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# destination -> list of tuple(failure, deferred)
|
|
||||||
self.pending_failures_by_dest = {}
|
|
||||||
|
|
||||||
# destination -> stream_id of last successfully sent to-device message.
|
# destination -> stream_id of last successfully sent to-device message.
|
||||||
# NB: may be a long or an int.
|
# NB: may be a long or an int.
|
||||||
self.last_device_stream_id_by_dest = {}
|
self.last_device_stream_id_by_dest = {}
|
||||||
@ -382,19 +379,6 @@ class TransactionQueue(object):
|
|||||||
|
|
||||||
self._attempt_new_transaction(destination)
|
self._attempt_new_transaction(destination)
|
||||||
|
|
||||||
def send_failure(self, failure, destination):
|
|
||||||
if destination == self.server_name or destination == "localhost":
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self.can_send_to(destination):
|
|
||||||
return
|
|
||||||
|
|
||||||
self.pending_failures_by_dest.setdefault(
|
|
||||||
destination, []
|
|
||||||
).append(failure)
|
|
||||||
|
|
||||||
self._attempt_new_transaction(destination)
|
|
||||||
|
|
||||||
def send_device_messages(self, destination):
|
def send_device_messages(self, destination):
|
||||||
if destination == self.server_name or destination == "localhost":
|
if destination == self.server_name or destination == "localhost":
|
||||||
return
|
return
|
||||||
@ -469,7 +453,6 @@ class TransactionQueue(object):
|
|||||||
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
||||||
pending_edus = self.pending_edus_by_dest.pop(destination, [])
|
pending_edus = self.pending_edus_by_dest.pop(destination, [])
|
||||||
pending_presence = self.pending_presence_by_dest.pop(destination, {})
|
pending_presence = self.pending_presence_by_dest.pop(destination, {})
|
||||||
pending_failures = self.pending_failures_by_dest.pop(destination, [])
|
|
||||||
|
|
||||||
pending_edus.extend(
|
pending_edus.extend(
|
||||||
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
|
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
|
||||||
@ -497,7 +480,7 @@ class TransactionQueue(object):
|
|||||||
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
||||||
destination, len(pending_pdus))
|
destination, len(pending_pdus))
|
||||||
|
|
||||||
if not pending_pdus and not pending_edus and not pending_failures:
|
if not pending_pdus and not pending_edus:
|
||||||
logger.debug("TX [%s] Nothing to send", destination)
|
logger.debug("TX [%s] Nothing to send", destination)
|
||||||
self.last_device_stream_id_by_dest[destination] = (
|
self.last_device_stream_id_by_dest[destination] = (
|
||||||
device_stream_id
|
device_stream_id
|
||||||
@ -507,7 +490,7 @@ class TransactionQueue(object):
|
|||||||
# END CRITICAL SECTION
|
# END CRITICAL SECTION
|
||||||
|
|
||||||
success = yield self._send_new_transaction(
|
success = yield self._send_new_transaction(
|
||||||
destination, pending_pdus, pending_edus, pending_failures,
|
destination, pending_pdus, pending_edus,
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
sent_transactions_counter.inc()
|
sent_transactions_counter.inc()
|
||||||
@ -584,14 +567,12 @@ class TransactionQueue(object):
|
|||||||
|
|
||||||
@measure_func("_send_new_transaction")
|
@measure_func("_send_new_transaction")
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
|
def _send_new_transaction(self, destination, pending_pdus, pending_edus):
|
||||||
pending_failures):
|
|
||||||
|
|
||||||
# Sort based on the order field
|
# Sort based on the order field
|
||||||
pending_pdus.sort(key=lambda t: t[1])
|
pending_pdus.sort(key=lambda t: t[1])
|
||||||
pdus = [x[0] for x in pending_pdus]
|
pdus = [x[0] for x in pending_pdus]
|
||||||
edus = pending_edus
|
edus = pending_edus
|
||||||
failures = [x.get_dict() for x in pending_failures]
|
|
||||||
|
|
||||||
success = True
|
success = True
|
||||||
|
|
||||||
@ -601,11 +582,10 @@ class TransactionQueue(object):
|
|||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"TX [%s] {%s} Attempting new transaction"
|
"TX [%s] {%s} Attempting new transaction"
|
||||||
" (pdus: %d, edus: %d, failures: %d)",
|
" (pdus: %d, edus: %d)",
|
||||||
destination, txn_id,
|
destination, txn_id,
|
||||||
len(pdus),
|
len(pdus),
|
||||||
len(edus),
|
len(edus),
|
||||||
len(failures)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("TX [%s] Persisting transaction...", destination)
|
logger.debug("TX [%s] Persisting transaction...", destination)
|
||||||
@ -617,7 +597,6 @@ class TransactionQueue(object):
|
|||||||
destination=destination,
|
destination=destination,
|
||||||
pdus=pdus,
|
pdus=pdus,
|
||||||
edus=edus,
|
edus=edus,
|
||||||
pdu_failures=failures,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._next_txn_id += 1
|
self._next_txn_id += 1
|
||||||
@ -627,12 +606,11 @@ class TransactionQueue(object):
|
|||||||
logger.debug("TX [%s] Persisted transaction", destination)
|
logger.debug("TX [%s] Persisted transaction", destination)
|
||||||
logger.info(
|
logger.info(
|
||||||
"TX [%s] {%s} Sending transaction [%s],"
|
"TX [%s] {%s} Sending transaction [%s],"
|
||||||
" (PDUs: %d, EDUs: %d, failures: %d)",
|
" (PDUs: %d, EDUs: %d)",
|
||||||
destination, txn_id,
|
destination, txn_id,
|
||||||
transaction.transaction_id,
|
transaction.transaction_id,
|
||||||
len(pdus),
|
len(pdus),
|
||||||
len(edus),
|
len(edus),
|
||||||
len(failures),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Actually send the transaction
|
# Actually send the transaction
|
||||||
|
@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes):
|
|||||||
param_dict = dict(kv.split("=") for kv in params)
|
param_dict = dict(kv.split("=") for kv in params)
|
||||||
|
|
||||||
def strip_quotes(value):
|
def strip_quotes(value):
|
||||||
if value.startswith(b"\""):
|
if value.startswith("\""):
|
||||||
return value[1:-1]
|
return value[1:-1]
|
||||||
else:
|
else:
|
||||||
return value
|
return value
|
||||||
@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
|
"Received txn %s from %s. (PDUs: %d, EDUs: %d)",
|
||||||
transaction_id, origin,
|
transaction_id, origin,
|
||||||
len(transaction_data.get("pdus", [])),
|
len(transaction_data.get("pdus", [])),
|
||||||
len(transaction_data.get("edus", [])),
|
len(transaction_data.get("edus", [])),
|
||||||
len(transaction_data.get("failures", [])),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# We should ideally be getting this from the security layer.
|
# We should ideally be getting this from the security layer.
|
||||||
|
@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject):
|
|||||||
"previous_ids",
|
"previous_ids",
|
||||||
"pdus",
|
"pdus",
|
||||||
"edus",
|
"edus",
|
||||||
"pdu_failures",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
internal_keys = [
|
internal_keys = [
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import bcrypt
|
import bcrypt
|
||||||
@ -519,6 +520,7 @@ class AuthHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||||
access_token = yield self.issue_access_token(user_id, device_id)
|
access_token = yield self.issue_access_token(user_id, device_id)
|
||||||
|
yield self._check_mau_limits()
|
||||||
|
|
||||||
# the device *should* have been registered before we got here; however,
|
# the device *should* have been registered before we got here; however,
|
||||||
# it's possible we raced against a DELETE operation. The thing we
|
# it's possible we raced against a DELETE operation. The thing we
|
||||||
@ -626,6 +628,7 @@ class AuthHandler(BaseHandler):
|
|||||||
# special case to check for "password" for the check_password interface
|
# special case to check for "password" for the check_password interface
|
||||||
# for the auth providers
|
# for the auth providers
|
||||||
password = login_submission.get("password")
|
password = login_submission.get("password")
|
||||||
|
|
||||||
if login_type == LoginType.PASSWORD:
|
if login_type == LoginType.PASSWORD:
|
||||||
if not self._password_enabled:
|
if not self._password_enabled:
|
||||||
raise SynapseError(400, "Password login has been disabled.")
|
raise SynapseError(400, "Password login has been disabled.")
|
||||||
@ -707,9 +710,10 @@ class AuthHandler(BaseHandler):
|
|||||||
multiple inexact matches.
|
multiple inexact matches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): complete @user:id
|
user_id (unicode): complete @user:id
|
||||||
|
password (unicode): the provided password
|
||||||
Returns:
|
Returns:
|
||||||
(str) the canonical_user_id, or None if unknown user / bad password
|
(unicode) the canonical_user_id, or None if unknown user / bad password
|
||||||
"""
|
"""
|
||||||
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
|
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
if not lookupres:
|
if not lookupres:
|
||||||
@ -728,15 +732,18 @@ class AuthHandler(BaseHandler):
|
|||||||
device_id)
|
device_id)
|
||||||
defer.returnValue(access_token)
|
defer.returnValue(access_token)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
||||||
|
yield self._check_mau_limits()
|
||||||
auth_api = self.hs.get_auth()
|
auth_api = self.hs.get_auth()
|
||||||
|
user_id = None
|
||||||
try:
|
try:
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
||||||
user_id = auth_api.get_user_id_from_macaroon(macaroon)
|
user_id = auth_api.get_user_id_from_macaroon(macaroon)
|
||||||
auth_api.validate_macaroon(macaroon, "login", True, user_id)
|
auth_api.validate_macaroon(macaroon, "login", True, user_id)
|
||||||
return user_id
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
||||||
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_access_token(self, access_token):
|
def delete_access_token(self, access_token):
|
||||||
@ -849,14 +856,19 @@ class AuthHandler(BaseHandler):
|
|||||||
"""Computes a secure hash of password.
|
"""Computes a secure hash of password.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
password (str): Password to hash.
|
password (unicode): Password to hash.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(str): Hashed password.
|
Deferred(unicode): Hashed password.
|
||||||
"""
|
"""
|
||||||
def _do_hash():
|
def _do_hash():
|
||||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
# Normalise the Unicode in the password
|
||||||
bcrypt.gensalt(self.bcrypt_rounds))
|
pw = unicodedata.normalize("NFKC", password)
|
||||||
|
|
||||||
|
return bcrypt.hashpw(
|
||||||
|
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
|
||||||
|
bcrypt.gensalt(self.bcrypt_rounds),
|
||||||
|
).decode('ascii')
|
||||||
|
|
||||||
return make_deferred_yieldable(
|
return make_deferred_yieldable(
|
||||||
threads.deferToThreadPool(
|
threads.deferToThreadPool(
|
||||||
@ -868,16 +880,19 @@ class AuthHandler(BaseHandler):
|
|||||||
"""Validates that self.hash(password) == stored_hash.
|
"""Validates that self.hash(password) == stored_hash.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
password (str): Password to hash.
|
password (unicode): Password to hash.
|
||||||
stored_hash (str): Expected hash value.
|
stored_hash (unicode): Expected hash value.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(bool): Whether self.hash(password) == stored_hash.
|
Deferred(bool): Whether self.hash(password) == stored_hash.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _do_validate_hash():
|
def _do_validate_hash():
|
||||||
|
# Normalise the Unicode in the password
|
||||||
|
pw = unicodedata.normalize("NFKC", password)
|
||||||
|
|
||||||
return bcrypt.checkpw(
|
return bcrypt.checkpw(
|
||||||
password.encode('utf8') + self.hs.config.password_pepper,
|
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
|
||||||
stored_hash.encode('utf8')
|
stored_hash.encode('utf8')
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -892,6 +907,19 @@ class AuthHandler(BaseHandler):
|
|||||||
else:
|
else:
|
||||||
return defer.succeed(False)
|
return defer.succeed(False)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _check_mau_limits(self):
|
||||||
|
"""
|
||||||
|
Ensure that if mau blocking is enabled that invalid users cannot
|
||||||
|
log in.
|
||||||
|
"""
|
||||||
|
if self.hs.config.limit_usage_by_mau is True:
|
||||||
|
current_mau = yield self.store.count_monthly_users()
|
||||||
|
if current_mau >= self.hs.config.max_mau_value:
|
||||||
|
raise AuthError(
|
||||||
|
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
class MacaroonGenerator(object):
|
class MacaroonGenerator(object):
|
||||||
|
@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
|
|||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.replication_layer = hs.get_federation_client()
|
self.federation_client = hs.get_federation_client()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
@ -255,7 +255,7 @@ class FederationHandler(BaseHandler):
|
|||||||
# know about
|
# know about
|
||||||
for p in prevs - seen:
|
for p in prevs - seen:
|
||||||
state, got_auth_chain = (
|
state, got_auth_chain = (
|
||||||
yield self.replication_layer.get_state_for_room(
|
yield self.federation_client.get_state_for_room(
|
||||||
origin, pdu.room_id, p
|
origin, pdu.room_id, p
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -338,7 +338,7 @@ class FederationHandler(BaseHandler):
|
|||||||
#
|
#
|
||||||
# see https://github.com/matrix-org/synapse/pull/1744
|
# see https://github.com/matrix-org/synapse/pull/1744
|
||||||
|
|
||||||
missing_events = yield self.replication_layer.get_missing_events(
|
missing_events = yield self.federation_client.get_missing_events(
|
||||||
origin,
|
origin,
|
||||||
pdu.room_id,
|
pdu.room_id,
|
||||||
earliest_events_ids=list(latest),
|
earliest_events_ids=list(latest),
|
||||||
@ -522,7 +522,7 @@ class FederationHandler(BaseHandler):
|
|||||||
if dest == self.server_name:
|
if dest == self.server_name:
|
||||||
raise SynapseError(400, "Can't backfill from self.")
|
raise SynapseError(400, "Can't backfill from self.")
|
||||||
|
|
||||||
events = yield self.replication_layer.backfill(
|
events = yield self.federation_client.backfill(
|
||||||
dest,
|
dest,
|
||||||
room_id,
|
room_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@ -570,7 +570,7 @@ class FederationHandler(BaseHandler):
|
|||||||
state_events = {}
|
state_events = {}
|
||||||
events_to_state = {}
|
events_to_state = {}
|
||||||
for e_id in edges:
|
for e_id in edges:
|
||||||
state, auth = yield self.replication_layer.get_state_for_room(
|
state, auth = yield self.federation_client.get_state_for_room(
|
||||||
destination=dest,
|
destination=dest,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
event_id=e_id
|
event_id=e_id
|
||||||
@ -612,7 +612,7 @@ class FederationHandler(BaseHandler):
|
|||||||
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||||
[
|
[
|
||||||
logcontext.run_in_background(
|
logcontext.run_in_background(
|
||||||
self.replication_layer.get_pdu,
|
self.federation_client.get_pdu,
|
||||||
[dest],
|
[dest],
|
||||||
event_id,
|
event_id,
|
||||||
outlier=True,
|
outlier=True,
|
||||||
@ -893,7 +893,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
Invites must be signed by the invitee's server before distribution.
|
Invites must be signed by the invitee's server before distribution.
|
||||||
"""
|
"""
|
||||||
pdu = yield self.replication_layer.send_invite(
|
pdu = yield self.federation_client.send_invite(
|
||||||
destination=target_host,
|
destination=target_host,
|
||||||
room_id=event.room_id,
|
room_id=event.room_id,
|
||||||
event_id=event.event_id,
|
event_id=event.event_id,
|
||||||
@ -955,7 +955,7 @@ class FederationHandler(BaseHandler):
|
|||||||
target_hosts.insert(0, origin)
|
target_hosts.insert(0, origin)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
ret = yield self.replication_layer.send_join(target_hosts, event)
|
ret = yield self.federation_client.send_join(target_hosts, event)
|
||||||
|
|
||||||
origin = ret["origin"]
|
origin = ret["origin"]
|
||||||
state = ret["state"]
|
state = ret["state"]
|
||||||
@ -1211,7 +1211,7 @@ class FederationHandler(BaseHandler):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
yield self.replication_layer.send_leave(
|
yield self.federation_client.send_leave(
|
||||||
target_hosts,
|
target_hosts,
|
||||||
event
|
event
|
||||||
)
|
)
|
||||||
@ -1234,7 +1234,7 @@ class FederationHandler(BaseHandler):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
|
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
|
||||||
content={},):
|
content={},):
|
||||||
origin, pdu = yield self.replication_layer.make_membership_event(
|
origin, pdu = yield self.federation_client.make_membership_event(
|
||||||
target_hosts,
|
target_hosts,
|
||||||
room_id,
|
room_id,
|
||||||
user_id,
|
user_id,
|
||||||
@ -1567,7 +1567,7 @@ class FederationHandler(BaseHandler):
|
|||||||
missing_auth_events.add(e_id)
|
missing_auth_events.add(e_id)
|
||||||
|
|
||||||
for e_id in missing_auth_events:
|
for e_id in missing_auth_events:
|
||||||
m_ev = yield self.replication_layer.get_pdu(
|
m_ev = yield self.federation_client.get_pdu(
|
||||||
[origin],
|
[origin],
|
||||||
e_id,
|
e_id,
|
||||||
outlier=True,
|
outlier=True,
|
||||||
@ -1777,7 +1777,7 @@ class FederationHandler(BaseHandler):
|
|||||||
logger.info("Missing auth: %s", missing_auth)
|
logger.info("Missing auth: %s", missing_auth)
|
||||||
# If we don't have all the auth events, we need to get them.
|
# If we don't have all the auth events, we need to get them.
|
||||||
try:
|
try:
|
||||||
remote_auth_chain = yield self.replication_layer.get_event_auth(
|
remote_auth_chain = yield self.federation_client.get_event_auth(
|
||||||
origin, event.room_id, event.event_id
|
origin, event.room_id, event.event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1893,7 +1893,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 2. Get remote difference.
|
# 2. Get remote difference.
|
||||||
result = yield self.replication_layer.query_auth(
|
result = yield self.federation_client.query_auth(
|
||||||
origin,
|
origin,
|
||||||
event.room_id,
|
event.room_id,
|
||||||
event.event_id,
|
event.event_id,
|
||||||
@ -2192,7 +2192,7 @@ class FederationHandler(BaseHandler):
|
|||||||
yield member_handler.send_membership_event(None, event, context)
|
yield member_handler.send_membership_event(None, event, context)
|
||||||
else:
|
else:
|
||||||
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
|
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
|
||||||
yield self.replication_layer.forward_third_party_invite(
|
yield self.federation_client.forward_third_party_invite(
|
||||||
destinations,
|
destinations,
|
||||||
room_id,
|
room_id,
|
||||||
event_dict,
|
event_dict,
|
||||||
|
@ -45,7 +45,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
hs (synapse.server.HomeServer):
|
hs (synapse.server.HomeServer):
|
||||||
"""
|
"""
|
||||||
super(RegistrationHandler, self).__init__(hs)
|
super(RegistrationHandler, self).__init__(hs)
|
||||||
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self.profile_handler = hs.get_profile_handler()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
Args:
|
Args:
|
||||||
localpart : The local part of the user ID to register. If None,
|
localpart : The local part of the user ID to register. If None,
|
||||||
one will be generated.
|
one will be generated.
|
||||||
password (str) : The password to assign to this user so they can
|
password (unicode) : The password to assign to this user so they can
|
||||||
login again. This can be None which means they cannot login again
|
login again. This can be None which means they cannot login again
|
||||||
via a password (e.g. the user is an application service user).
|
via a password (e.g. the user is an application service user).
|
||||||
generate_token (bool): Whether a new access token should be
|
generate_token (bool): Whether a new access token should be
|
||||||
@ -144,6 +144,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
Raises:
|
Raises:
|
||||||
RegistrationError if there was a problem registering.
|
RegistrationError if there was a problem registering.
|
||||||
"""
|
"""
|
||||||
|
yield self._check_mau_limits()
|
||||||
password_hash = None
|
password_hash = None
|
||||||
if password:
|
if password:
|
||||||
password_hash = yield self.auth_handler().hash(password)
|
password_hash = yield self.auth_handler().hash(password)
|
||||||
@ -288,6 +289,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
400,
|
400,
|
||||||
"User ID can only contain characters a-z, 0-9, or '=_-./'",
|
"User ID can only contain characters a-z, 0-9, or '=_-./'",
|
||||||
)
|
)
|
||||||
|
yield self._check_mau_limits()
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
@ -437,7 +439,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
if localpart is None:
|
if localpart is None:
|
||||||
raise SynapseError(400, "Request must include user id")
|
raise SynapseError(400, "Request must include user id")
|
||||||
|
yield self._check_mau_limits()
|
||||||
need_register = True
|
need_register = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -531,3 +533,16 @@ class RegistrationHandler(BaseHandler):
|
|||||||
remote_room_hosts=remote_room_hosts,
|
remote_room_hosts=remote_room_hosts,
|
||||||
action="join",
|
action="join",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _check_mau_limits(self):
|
||||||
|
"""
|
||||||
|
Do not accept registrations if monthly active user limits exceeded
|
||||||
|
and limiting is enabled
|
||||||
|
"""
|
||||||
|
if self.hs.config.limit_usage_by_mau is True:
|
||||||
|
current_mau = yield self.store.count_monthly_users()
|
||||||
|
if current_mau >= self.hs.config.max_mau_value:
|
||||||
|
raise RegistrationError(
|
||||||
|
403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
|
||||||
|
)
|
||||||
|
@ -13,12 +13,13 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import cgi
|
import cgi
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
|
||||||
|
|
||||||
from six.moves import http_client
|
from six import PY3
|
||||||
|
from six.moves import http_client, urllib
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
||||||
|
|
||||||
@ -264,6 +265,7 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
def register_paths(self, method, path_patterns, callback):
|
def register_paths(self, method, path_patterns, callback):
|
||||||
|
method = method.encode("utf-8") # method is bytes on py3
|
||||||
for path_pattern in path_patterns:
|
for path_pattern in path_patterns:
|
||||||
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||||
self.path_regexs.setdefault(method, []).append(
|
self.path_regexs.setdefault(method, []).append(
|
||||||
@ -296,8 +298,19 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
# here. If it throws an exception, that is handled by the wrapper
|
# here. If it throws an exception, that is handled by the wrapper
|
||||||
# installed by @request_handler.
|
# installed by @request_handler.
|
||||||
|
|
||||||
|
def _unquote(s):
|
||||||
|
if PY3:
|
||||||
|
# On Python 3, unquote is unicode -> unicode
|
||||||
|
return urllib.parse.unquote(s)
|
||||||
|
else:
|
||||||
|
# On Python 2, unquote is bytes -> bytes We need to encode the
|
||||||
|
# URL again (as it was decoded by _get_handler_for request), as
|
||||||
|
# ASCII because it's a URL, and then decode it to get the UTF-8
|
||||||
|
# characters that were quoted.
|
||||||
|
return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
|
||||||
|
|
||||||
kwargs = intern_dict({
|
kwargs = intern_dict({
|
||||||
name: urllib.unquote(value).decode("UTF-8") if value else value
|
name: _unquote(value) if value else value
|
||||||
for name, value in group_dict.items()
|
for name, value in group_dict.items()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -313,9 +326,9 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
request (twisted.web.http.Request):
|
request (twisted.web.http.Request):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Callable, dict[str, str]]: callback method, and the dict
|
Tuple[Callable, dict[unicode, unicode]]: callback method, and the
|
||||||
mapping keys to path components as specified in the handler's
|
dict mapping keys to path components as specified in the
|
||||||
path match regexp.
|
handler's path match regexp.
|
||||||
|
|
||||||
The callback will normally be a method registered via
|
The callback will normally be a method registered via
|
||||||
register_paths, so will return (possibly via Deferred) either
|
register_paths, so will return (possibly via Deferred) either
|
||||||
@ -327,7 +340,7 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
# Loop through all the registered callbacks to check if the method
|
# Loop through all the registered callbacks to check if the method
|
||||||
# and path regex match
|
# and path regex match
|
||||||
for path_entry in self.path_regexs.get(request.method, []):
|
for path_entry in self.path_regexs.get(request.method, []):
|
||||||
m = path_entry.pattern.match(request.path)
|
m = path_entry.pattern.match(request.path.decode('ascii'))
|
||||||
if m:
|
if m:
|
||||||
# We found a match!
|
# We found a match!
|
||||||
return path_entry.callback, m.groupdict()
|
return path_entry.callback, m.groupdict()
|
||||||
@ -383,7 +396,7 @@ class RootRedirect(resource.Resource):
|
|||||||
self.url = path
|
self.url = path
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
return redirectTo(self.url, request)
|
return redirectTo(self.url.encode('ascii'), request)
|
||||||
|
|
||||||
def getChild(self, name, request):
|
def getChild(self, name, request):
|
||||||
if len(name) == 0:
|
if len(name) == 0:
|
||||||
@ -404,12 +417,14 @@ def respond_with_json(request, code, json_object, send_cors=False,
|
|||||||
return
|
return
|
||||||
|
|
||||||
if pretty_print:
|
if pretty_print:
|
||||||
json_bytes = encode_pretty_printed_json(json_object) + "\n"
|
json_bytes = (encode_pretty_printed_json(json_object) + "\n"
|
||||||
|
).encode("utf-8")
|
||||||
else:
|
else:
|
||||||
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
||||||
|
# canonicaljson already encodes to bytes
|
||||||
json_bytes = encode_canonical_json(json_object)
|
json_bytes = encode_canonical_json(json_object)
|
||||||
else:
|
else:
|
||||||
json_bytes = json.dumps(json_object)
|
json_bytes = json.dumps(json_object).encode("utf-8")
|
||||||
|
|
||||||
return respond_with_json_bytes(
|
return respond_with_json_bytes(
|
||||||
request, code, json_bytes,
|
request, code, json_bytes,
|
||||||
|
@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False):
|
|||||||
if not content_bytes and allow_empty_body:
|
if not content_bytes and allow_empty_body:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Decode to Unicode so that simplejson will return Unicode strings on
|
||||||
|
# Python 2
|
||||||
try:
|
try:
|
||||||
content = json.loads(content_bytes)
|
content_unicode = content_bytes.decode('utf8')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
logger.warn("Unable to decode UTF-8")
|
||||||
|
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = json.loads(content_unicode)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn("Unable to parse JSON: %s", e)
|
logger.warn("Unable to parse JSON: %s", e)
|
||||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||||
|
@ -18,6 +18,7 @@ import hashlib
|
|||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from six import text_type
|
||||||
from six.moves import http_client
|
from six.moves import http_client
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet):
|
|||||||
400, "username must be specified", errcode=Codes.BAD_JSON,
|
400, "username must be specified", errcode=Codes.BAD_JSON,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if (not isinstance(body['username'], str) or len(body['username']) > 512):
|
if (
|
||||||
|
not isinstance(body['username'], text_type)
|
||||||
|
or len(body['username']) > 512
|
||||||
|
):
|
||||||
raise SynapseError(400, "Invalid username")
|
raise SynapseError(400, "Invalid username")
|
||||||
|
|
||||||
username = body["username"].encode("utf-8")
|
username = body["username"].encode("utf-8")
|
||||||
@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet):
|
|||||||
400, "password must be specified", errcode=Codes.BAD_JSON,
|
400, "password must be specified", errcode=Codes.BAD_JSON,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if (not isinstance(body['password'], str) or len(body['password']) > 512):
|
if (
|
||||||
|
not isinstance(body['password'], text_type)
|
||||||
|
or len(body['password']) > 512
|
||||||
|
):
|
||||||
raise SynapseError(400, "Invalid password")
|
raise SynapseError(400, "Invalid password")
|
||||||
|
|
||||||
password = body["password"].encode("utf-8")
|
password = body["password"].encode("utf-8")
|
||||||
@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet):
|
|||||||
want_mac.update(b"admin" if admin else b"notadmin")
|
want_mac.update(b"admin" if admin else b"notadmin")
|
||||||
want_mac = want_mac.hexdigest()
|
want_mac = want_mac.hexdigest()
|
||||||
|
|
||||||
if not hmac.compare_digest(want_mac, got_mac):
|
if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
|
||||||
raise SynapseError(
|
raise SynapseError(403, "HMAC incorrect")
|
||||||
403, "HMAC incorrect",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reuse the parts of RegisterRestServlet to reduce code duplication
|
# Reuse the parts of RegisterRestServlet to reduce code duplication
|
||||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||||
|
|
||||||
register = RegisterRestServlet(self.hs)
|
register = RegisterRestServlet(self.hs)
|
||||||
|
|
||||||
(user_id, _) = yield register.registration_handler.register(
|
(user_id, _) = yield register.registration_handler.register(
|
||||||
localpart=username.lower(), password=password, admin=bool(admin),
|
localpart=body['username'].lower(),
|
||||||
|
password=body["password"],
|
||||||
|
admin=bool(admin),
|
||||||
generate_token=False,
|
generate_token=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ import logging
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.types import RoomAlias
|
from synapse.types import RoomAlias
|
||||||
|
|
||||||
@ -159,7 +159,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
|
|||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
room = yield self.store.get_room(room_id)
|
room = yield self.store.get_room(room_id)
|
||||||
if room is None:
|
if room is None:
|
||||||
raise SynapseError(400, "Unknown room")
|
raise NotFoundError("Unknown room")
|
||||||
|
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
"visibility": "public" if room["is_public"] else "private"
|
"visibility": "public" if room["is_public"] else "private"
|
||||||
|
@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet):
|
|||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
kind = "user"
|
kind = b"user"
|
||||||
if "kind" in request.args:
|
if b"kind" in request.args:
|
||||||
kind = request.args["kind"][0]
|
kind = request.args[b"kind"][0]
|
||||||
|
|
||||||
if kind == "guest":
|
if kind == b"guest":
|
||||||
ret = yield self._do_guest_registration(body)
|
ret = yield self._do_guest_registration(body)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
return
|
return
|
||||||
elif kind != "user":
|
elif kind != b"user":
|
||||||
raise UnrecognizedRequestError(
|
raise UnrecognizedRequestError(
|
||||||
"Do not understand membership kind: %s" % (kind,)
|
"Do not understand membership kind: %s" % (kind,)
|
||||||
)
|
)
|
||||||
@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet):
|
|||||||
assert_params_in_dict(params, ["password"])
|
assert_params_in_dict(params, ["password"])
|
||||||
|
|
||||||
desired_username = params.get("username", None)
|
desired_username = params.get("username", None)
|
||||||
new_password = params.get("password", None)
|
|
||||||
guest_access_token = params.get("guest_access_token", None)
|
guest_access_token = params.get("guest_access_token", None)
|
||||||
|
new_password = params.get("password", None)
|
||||||
|
|
||||||
if desired_username is not None:
|
if desired_username is not None:
|
||||||
desired_username = desired_username.lower()
|
desired_username = desired_username.lower()
|
||||||
|
@ -177,7 +177,7 @@ class MediaStorage(object):
|
|||||||
if res:
|
if res:
|
||||||
with res:
|
with res:
|
||||||
consumer = BackgroundFileConsumer(
|
consumer = BackgroundFileConsumer(
|
||||||
open(local_path, "w"), self.hs.get_reactor())
|
open(local_path, "wb"), self.hs.get_reactor())
|
||||||
yield res.write_to_consumer(consumer)
|
yield res.write_to_consumer(consumer)
|
||||||
yield consumer.wait()
|
yield consumer.wait()
|
||||||
defer.returnValue(local_path)
|
defer.returnValue(local_path)
|
||||||
|
@ -577,7 +577,7 @@ def _make_state_cache_entry(
|
|||||||
|
|
||||||
def _ordered_events(events):
|
def _ordered_events(events):
|
||||||
def key_func(e):
|
def key_func(e):
|
||||||
return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
|
return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
|
||||||
|
|
||||||
return sorted(events, key=key_func)
|
return sorted(events, key=key_func)
|
||||||
|
|
||||||
|
@ -66,6 +66,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
PresenceStore, TransactionStore,
|
PresenceStore, TransactionStore,
|
||||||
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
||||||
ApplicationServiceStore,
|
ApplicationServiceStore,
|
||||||
|
EventsStore,
|
||||||
EventFederationStore,
|
EventFederationStore,
|
||||||
MediaRepositoryStore,
|
MediaRepositoryStore,
|
||||||
RejectionsStore,
|
RejectionsStore,
|
||||||
@ -73,7 +74,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
PusherStore,
|
PusherStore,
|
||||||
PushRuleStore,
|
PushRuleStore,
|
||||||
ApplicationServiceTransactionStore,
|
ApplicationServiceTransactionStore,
|
||||||
EventsStore,
|
|
||||||
ReceiptsStore,
|
ReceiptsStore,
|
||||||
EndToEndKeyStore,
|
EndToEndKeyStore,
|
||||||
SearchStore,
|
SearchStore,
|
||||||
@ -94,6 +94,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.database_engine = hs.database_engine
|
self.database_engine = hs.database_engine
|
||||||
|
|
||||||
|
self.db_conn = db_conn
|
||||||
self._stream_id_gen = StreamIdGenerator(
|
self._stream_id_gen = StreamIdGenerator(
|
||||||
db_conn, "events", "stream_ordering",
|
db_conn, "events", "stream_ordering",
|
||||||
extra_tables=[("local_invites", "stream_id")]
|
extra_tables=[("local_invites", "stream_id")]
|
||||||
@ -266,6 +267,31 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
|
|
||||||
return self.runInteraction("count_users", _count_users)
|
return self.runInteraction("count_users", _count_users)
|
||||||
|
|
||||||
|
def count_monthly_users(self):
|
||||||
|
"""Counts the number of users who used this homeserver in the last 30 days
|
||||||
|
|
||||||
|
This method should be refactored with count_daily_users - the only
|
||||||
|
reason not to is waiting on definition of mau
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Defered[int]
|
||||||
|
"""
|
||||||
|
def _count_monthly_users(txn):
|
||||||
|
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
|
||||||
|
sql = """
|
||||||
|
SELECT COALESCE(count(*), 0) FROM (
|
||||||
|
SELECT user_id FROM user_ips
|
||||||
|
WHERE last_seen > ?
|
||||||
|
GROUP BY user_id
|
||||||
|
) u
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (thirty_days_ago,))
|
||||||
|
count, = txn.fetchone()
|
||||||
|
return count
|
||||||
|
|
||||||
|
return self.runInteraction("count_monthly_users", _count_monthly_users)
|
||||||
|
|
||||||
def count_r30_users(self):
|
def count_r30_users(self):
|
||||||
"""
|
"""
|
||||||
Counts the number of 30 day retained users, defined as:-
|
Counts the number of 30 day retained users, defined as:-
|
||||||
|
@ -22,7 +22,7 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from synapse.appservice import AppServiceTransaction
|
from synapse.appservice import AppServiceTransaction
|
||||||
from synapse.config.appservice import load_appservices
|
from synapse.config.appservice import load_appservices
|
||||||
from synapse.storage.events import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ from twisted.internet import defer
|
|||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.events import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.signatures import SignatureWorkerStore
|
from synapse.storage.signatures import SignatureWorkerStore
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
|
@ -34,6 +34,8 @@ from synapse.api.errors import SynapseError
|
|||||||
from synapse.events import EventBase # noqa: F401
|
from synapse.events import EventBase # noqa: F401
|
||||||
from synapse.events.snapshot import EventContext # noqa: F401
|
from synapse.events.snapshot import EventContext # noqa: F401
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||||
|
from synapse.storage.event_federation import EventFederationStore
|
||||||
from synapse.storage.events_worker import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.types import RoomStreamToken, get_domain_from_id
|
from synapse.types import RoomStreamToken, get_domain_from_id
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
@ -65,7 +67,13 @@ state_delta_reuse_delta_counter = Counter(
|
|||||||
|
|
||||||
|
|
||||||
def encode_json(json_object):
|
def encode_json(json_object):
|
||||||
return frozendict_json_encoder.encode(json_object)
|
"""
|
||||||
|
Encode a Python object as JSON and return it in a Unicode string.
|
||||||
|
"""
|
||||||
|
out = frozendict_json_encoder.encode(json_object)
|
||||||
|
if isinstance(out, bytes):
|
||||||
|
out = out.decode('utf8')
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class _EventPeristenceQueue(object):
|
class _EventPeristenceQueue(object):
|
||||||
@ -193,7 +201,9 @@ def _retry_on_integrity_error(func):
|
|||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
class EventsStore(EventsWorkerStore):
|
# inherits from EventFederationStore so that we can call _update_backward_extremities
|
||||||
|
# and _handle_mult_prev_events (though arguably those could both be moved in here)
|
||||||
|
class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore):
|
||||||
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
||||||
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
||||||
|
|
||||||
@ -1054,7 +1064,7 @@ class EventsStore(EventsWorkerStore):
|
|||||||
|
|
||||||
metadata_json = encode_json(
|
metadata_json = encode_json(
|
||||||
event.internal_metadata.get_dict()
|
event.internal_metadata.get_dict()
|
||||||
).decode("UTF-8")
|
)
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"UPDATE event_json SET internal_metadata = ?"
|
"UPDATE event_json SET internal_metadata = ?"
|
||||||
@ -1168,8 +1178,8 @@ class EventsStore(EventsWorkerStore):
|
|||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
"internal_metadata": encode_json(
|
"internal_metadata": encode_json(
|
||||||
event.internal_metadata.get_dict()
|
event.internal_metadata.get_dict()
|
||||||
).decode("UTF-8"),
|
),
|
||||||
"json": encode_json(event_dict(event)).decode("UTF-8"),
|
"json": encode_json(event_dict(event)),
|
||||||
}
|
}
|
||||||
for event, _ in events_and_contexts
|
for event, _ in events_and_contexts
|
||||||
],
|
],
|
||||||
|
@ -24,7 +24,7 @@ from canonicaljson import json
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.storage.events import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
|
@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
|
|||||||
txn (cursor):
|
txn (cursor):
|
||||||
event_id (str): Id for the Event.
|
event_id (str): Id for the Event.
|
||||||
Returns:
|
Returns:
|
||||||
A dict of algorithm -> hash.
|
A dict[unicode, bytes] of algorithm -> hash.
|
||||||
"""
|
"""
|
||||||
query = (
|
query = (
|
||||||
"SELECT algorithm, hash"
|
"SELECT algorithm, hash"
|
||||||
|
@ -43,7 +43,7 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.events import EventsWorkerStore
|
from synapse.storage.events_worker import EventsWorkerStore
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||||
|
@ -137,7 +137,7 @@ class DomainSpecificString(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, s):
|
def from_string(cls, s):
|
||||||
"""Parse the string given by 's' into a structure object."""
|
"""Parse the string given by 's' into a structure object."""
|
||||||
if len(s) < 1 or s[0] != cls.SIGIL:
|
if len(s) < 1 or s[0:1] != cls.SIGIL:
|
||||||
raise SynapseError(400, "Expected %s string to start with '%s'" % (
|
raise SynapseError(400, "Expected %s string to start with '%s'" % (
|
||||||
cls.__name__, cls.SIGIL,
|
cls.__name__, cls.SIGIL,
|
||||||
))
|
))
|
||||||
|
@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
|||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
# If we're passed a cache_context then we'll want to call its
|
||||||
# whenever we are invalidated
|
# invalidate() whenever we are invalidated
|
||||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
|
||||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
||||||
list_args = arg_dict[self.list_name]
|
list_args = arg_dict[self.list_name]
|
||||||
|
|
||||||
# cached is a dict arg -> deferred, where deferred results in a
|
|
||||||
# 2-tuple (`arg`, `result`)
|
|
||||||
results = {}
|
results = {}
|
||||||
cached_defers = {}
|
|
||||||
missing = []
|
def update_results_dict(res, arg):
|
||||||
|
results[arg] = res
|
||||||
|
|
||||||
|
# list of deferreds to wait for
|
||||||
|
cached_defers = []
|
||||||
|
|
||||||
|
missing = set()
|
||||||
|
|
||||||
# If the cache takes a single arg then that is used as the key,
|
# If the cache takes a single arg then that is used as the key,
|
||||||
# otherwise a tuple is used.
|
# otherwise a tuple is used.
|
||||||
if num_args == 1:
|
if num_args == 1:
|
||||||
def cache_get(arg):
|
def arg_to_cache_key(arg):
|
||||||
return cache.get(arg, callback=invalidate_callback)
|
return arg
|
||||||
else:
|
else:
|
||||||
key = list(keyargs)
|
keylist = list(keyargs)
|
||||||
|
|
||||||
def cache_get(arg):
|
def arg_to_cache_key(arg):
|
||||||
key[self.list_pos] = arg
|
keylist[self.list_pos] = arg
|
||||||
return cache.get(tuple(key), callback=invalidate_callback)
|
return tuple(keylist)
|
||||||
|
|
||||||
for arg in list_args:
|
for arg in list_args:
|
||||||
try:
|
try:
|
||||||
res = cache_get(arg)
|
res = cache.get(arg_to_cache_key(arg),
|
||||||
|
callback=invalidate_callback)
|
||||||
if not isinstance(res, ObservableDeferred):
|
if not isinstance(res, ObservableDeferred):
|
||||||
results[arg] = res
|
results[arg] = res
|
||||||
elif not res.has_succeeded():
|
elif not res.has_succeeded():
|
||||||
res = res.observe()
|
res = res.observe()
|
||||||
res.addCallback(lambda r, arg: (arg, r), arg)
|
res.addCallback(update_results_dict, arg)
|
||||||
cached_defers[arg] = res
|
cached_defers.append(res)
|
||||||
else:
|
else:
|
||||||
results[arg] = res.get_result()
|
results[arg] = res.get_result()
|
||||||
except KeyError:
|
except KeyError:
|
||||||
missing.append(arg)
|
missing.add(arg)
|
||||||
|
|
||||||
if missing:
|
if missing:
|
||||||
args_to_call = dict(arg_dict)
|
# we need an observable deferred for each entry in the list,
|
||||||
args_to_call[self.list_name] = missing
|
# which we put in the cache. Each deferred resolves with the
|
||||||
|
# relevant result for that key.
|
||||||
|
deferreds_map = {}
|
||||||
|
for arg in missing:
|
||||||
|
deferred = defer.Deferred()
|
||||||
|
deferreds_map[arg] = deferred
|
||||||
|
key = arg_to_cache_key(arg)
|
||||||
|
observable = ObservableDeferred(deferred)
|
||||||
|
cache.set(key, observable, callback=invalidate_callback)
|
||||||
|
|
||||||
ret_d = defer.maybeDeferred(
|
def complete_all(res):
|
||||||
|
# the wrapped function has completed. It returns a
|
||||||
|
# a dict. We can now resolve the observable deferreds in
|
||||||
|
# the cache and update our own result map.
|
||||||
|
for e in missing:
|
||||||
|
val = res.get(e, None)
|
||||||
|
deferreds_map[e].callback(val)
|
||||||
|
results[e] = val
|
||||||
|
|
||||||
|
def errback(f):
|
||||||
|
# the wrapped function has failed. Invalidate any cache
|
||||||
|
# entries we're supposed to be populating, and fail
|
||||||
|
# their deferreds.
|
||||||
|
for e in missing:
|
||||||
|
key = arg_to_cache_key(e)
|
||||||
|
cache.invalidate(key)
|
||||||
|
deferreds_map[e].errback(f)
|
||||||
|
|
||||||
|
# return the failure, to propagate to our caller.
|
||||||
|
return f
|
||||||
|
|
||||||
|
args_to_call = dict(arg_dict)
|
||||||
|
args_to_call[self.list_name] = list(missing)
|
||||||
|
|
||||||
|
cached_defers.append(defer.maybeDeferred(
|
||||||
logcontext.preserve_fn(self.function_to_call),
|
logcontext.preserve_fn(self.function_to_call),
|
||||||
**args_to_call
|
**args_to_call
|
||||||
)
|
).addCallbacks(complete_all, errback))
|
||||||
|
|
||||||
ret_d = ObservableDeferred(ret_d)
|
|
||||||
|
|
||||||
# We need to create deferreds for each arg in the list so that
|
|
||||||
# we can insert the new deferred into the cache.
|
|
||||||
for arg in missing:
|
|
||||||
observer = ret_d.observe()
|
|
||||||
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
|
||||||
|
|
||||||
observer = ObservableDeferred(observer)
|
|
||||||
|
|
||||||
if num_args == 1:
|
|
||||||
cache.set(
|
|
||||||
arg, observer,
|
|
||||||
callback=invalidate_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
def invalidate(f, key):
|
|
||||||
cache.invalidate(key)
|
|
||||||
return f
|
|
||||||
observer.addErrback(invalidate, arg)
|
|
||||||
else:
|
|
||||||
key = list(keyargs)
|
|
||||||
key[self.list_pos] = arg
|
|
||||||
cache.set(
|
|
||||||
tuple(key), observer,
|
|
||||||
callback=invalidate_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
def invalidate(f, key):
|
|
||||||
cache.invalidate(key)
|
|
||||||
return f
|
|
||||||
observer.addErrback(invalidate, tuple(key))
|
|
||||||
|
|
||||||
res = observer.observe()
|
|
||||||
res.addCallback(lambda r, arg: (arg, r), arg)
|
|
||||||
|
|
||||||
cached_defers[arg] = res
|
|
||||||
|
|
||||||
if cached_defers:
|
if cached_defers:
|
||||||
def update_results_dict(res):
|
d = defer.gatherResults(
|
||||||
results.update(res)
|
cached_defers,
|
||||||
return results
|
|
||||||
|
|
||||||
return logcontext.make_deferred_yieldable(defer.gatherResults(
|
|
||||||
list(cached_defers.values()),
|
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addCallback(update_results_dict).addErrback(
|
).addCallbacks(
|
||||||
|
lambda _: results,
|
||||||
unwrapFirstError
|
unwrapFirstError
|
||||||
))
|
)
|
||||||
|
return logcontext.make_deferred_yieldable(d)
|
||||||
else:
|
else:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
|
|||||||
cache.
|
cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache (Cache): The underlying cache to use.
|
cached_method_name (str): The name of the single-item lookup method.
|
||||||
|
This is only used to find the cache to use.
|
||||||
list_name (str): The name of the argument that is the list to use to
|
list_name (str): The name of the argument that is the list to use to
|
||||||
do batch lookups in the cache.
|
do batch lookups in the cache.
|
||||||
num_args (int): Number of arguments to use as the key in the cache
|
num_args (int): Number of arguments to use as the key in the cache
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from six import string_types
|
from six import binary_type, text_type
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
@ -26,7 +26,7 @@ def freeze(o):
|
|||||||
if isinstance(o, frozendict):
|
if isinstance(o, frozendict):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
if isinstance(o, string_types):
|
if isinstance(o, (binary_type, text_type)):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -41,7 +41,7 @@ def unfreeze(o):
|
|||||||
if isinstance(o, (dict, frozendict)):
|
if isinstance(o, (dict, frozendict)):
|
||||||
return dict({k: unfreeze(v) for k, v in o.items()})
|
return dict({k: unfreeze(v) for k, v in o.items()})
|
||||||
|
|
||||||
if isinstance(o, string_types):
|
if isinstance(o, (binary_type, text_type)):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.auth = Auth(self.hs)
|
self.auth = Auth(self.hs)
|
||||||
|
|
||||||
self.test_user = "@foo:bar"
|
self.test_user = "@foo:bar"
|
||||||
self.test_token = "_test_token_"
|
self.test_token = b"_test_token_"
|
||||||
|
|
||||||
# this is overridden for the appservice tests
|
# this is overridden for the appservice tests
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "192.168.10.10"
|
request.getClientIP.return_value = "192.168.10.10"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "131.111.8.42"
|
request.getClientIP.return_value = "131.111.8.42"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
|
||||||
masquerading_user_id = "@doppelganger:matrix.org"
|
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user,
|
token="foobar", url="a_url", sender=self.test_user,
|
||||||
ip_range_whitelist=None,
|
ip_range_whitelist=None,
|
||||||
@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args["user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), masquerading_user_id)
|
self.assertEquals(
|
||||||
|
requester.user.to_string(),
|
||||||
|
masquerading_user_id.decode('utf8')
|
||||||
|
)
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
||||||
masquerading_user_id = "@doppelganger:matrix.org"
|
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user,
|
token="foobar", url="a_url", sender=self.test_user,
|
||||||
ip_range_whitelist=None,
|
ip_range_whitelist=None,
|
||||||
@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args["user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
# check the token works
|
# check the token works
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [token]
|
request.args[b"access_token"] = [token.encode('ascii')]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
self.assertEqual(UserID.from_string(USER_ID), requester.user)
|
self.assertEqual(UserID.from_string(USER_ID), requester.user)
|
||||||
@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
# the token should *not* work now
|
# the token should *not* work now
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [guest_tok]
|
request.args[b"access_token"] = [guest_tok.encode('ascii')]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
|
||||||
with self.assertRaises(AuthError) as cm:
|
with self.assertRaises(AuthError) as cm:
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
@ -19,6 +20,7 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
import synapse
|
import synapse
|
||||||
import synapse.api.errors
|
import synapse.api.errors
|
||||||
|
from synapse.api.errors import AuthError
|
||||||
from synapse.handlers.auth import AuthHandler
|
from synapse.handlers.auth import AuthHandler
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
@ -37,6 +39,10 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.hs.handlers = AuthHandlers(self.hs)
|
self.hs.handlers = AuthHandlers(self.hs)
|
||||||
self.auth_handler = self.hs.handlers.auth_handler
|
self.auth_handler = self.hs.handlers.auth_handler
|
||||||
self.macaroon_generator = self.hs.get_macaroon_generator()
|
self.macaroon_generator = self.hs.get_macaroon_generator()
|
||||||
|
# MAU tests
|
||||||
|
self.hs.config.max_mau_value = 50
|
||||||
|
self.small_number_of_users = 1
|
||||||
|
self.large_number_of_users = 100
|
||||||
|
|
||||||
def test_token_is_a_macaroon(self):
|
def test_token_is_a_macaroon(self):
|
||||||
token = self.macaroon_generator.generate_access_token("some_user")
|
token = self.macaroon_generator.generate_access_token("some_user")
|
||||||
@ -71,38 +77,37 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
v.satisfy_general(verify_nonce)
|
v.satisfy_general(verify_nonce)
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_short_term_login_token_gives_user_id(self):
|
def test_short_term_login_token_gives_user_id(self):
|
||||||
self.hs.clock.now = 1000
|
self.hs.clock.now = 1000
|
||||||
|
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
"a_user", 5000
|
"a_user", 5000
|
||||||
)
|
)
|
||||||
|
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
self.assertEqual(
|
token
|
||||||
"a_user",
|
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
|
||||||
token
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
self.assertEqual("a_user", user_id)
|
||||||
|
|
||||||
# when we advance the clock, the token should be rejected
|
# when we advance the clock, the token should be rejected
|
||||||
self.hs.clock.now = 6000
|
self.hs.clock.now = 6000
|
||||||
with self.assertRaises(synapse.api.errors.AuthError):
|
with self.assertRaises(synapse.api.errors.AuthError):
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
token
|
token
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
"a_user", 5000
|
"a_user", 5000
|
||||||
)
|
)
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
||||||
|
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
macaroon.serialize()
|
||||||
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"a_user",
|
"a_user", user_id
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
|
||||||
macaroon.serialize()
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# add another "user_id" caveat, which might allow us to override the
|
# add another "user_id" caveat, which might allow us to override the
|
||||||
@ -110,6 +115,57 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon.add_first_party_caveat("user_id = b_user")
|
macaroon.add_first_party_caveat("user_id = b_user")
|
||||||
|
|
||||||
with self.assertRaises(synapse.api.errors.AuthError):
|
with self.assertRaises(synapse.api.errors.AuthError):
|
||||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
macaroon.serialize()
|
macaroon.serialize()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_mau_limits_disabled(self):
|
||||||
|
self.hs.config.limit_usage_by_mau = False
|
||||||
|
# Ensure does not throw exception
|
||||||
|
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||||
|
|
||||||
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
self._get_macaroon().serialize()
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_mau_limits_exceeded(self):
|
||||||
|
self.hs.config.limit_usage_by_mau = True
|
||||||
|
self.hs.get_datastore().count_monthly_users = Mock(
|
||||||
|
return_value=defer.succeed(self.large_number_of_users)
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(AuthError):
|
||||||
|
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||||
|
|
||||||
|
self.hs.get_datastore().count_monthly_users = Mock(
|
||||||
|
return_value=defer.succeed(self.large_number_of_users)
|
||||||
|
)
|
||||||
|
with self.assertRaises(AuthError):
|
||||||
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
self._get_macaroon().serialize()
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_mau_limits_not_exceeded(self):
|
||||||
|
self.hs.config.limit_usage_by_mau = True
|
||||||
|
|
||||||
|
self.hs.get_datastore().count_monthly_users = Mock(
|
||||||
|
return_value=defer.succeed(self.small_number_of_users)
|
||||||
|
)
|
||||||
|
# Ensure does not raise exception
|
||||||
|
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||||
|
|
||||||
|
self.hs.get_datastore().count_monthly_users = Mock(
|
||||||
|
return_value=defer.succeed(self.small_number_of_users)
|
||||||
|
)
|
||||||
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
self._get_macaroon().serialize()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_macaroon(self):
|
||||||
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
|
"user_a", 5000
|
||||||
|
)
|
||||||
|
return pymacaroons.Macaroon.deserialize(token)
|
||||||
|
@ -17,6 +17,7 @@ from mock import Mock
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import RegistrationError
|
||||||
from synapse.handlers.register import RegistrationHandler
|
from synapse.handlers.register import RegistrationHandler
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
@ -77,3 +78,53 @@ class RegistrationTestCase(unittest.TestCase):
|
|||||||
requester, local_part, display_name)
|
requester, local_part, display_name)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cannot_register_when_mau_limits_exceeded(self):
|
||||||
|
local_part = "someone"
|
||||||
|
display_name = "someone"
|
||||||
|
requester = create_requester("@as:test")
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
self.hs.config.limit_usage_by_mau = False
|
||||||
|
self.hs.config.max_mau_value = 50
|
||||||
|
lots_of_users = 100
|
||||||
|
small_number_users = 1
|
||||||
|
|
||||||
|
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||||
|
|
||||||
|
# Ensure does not throw exception
|
||||||
|
yield self.handler.get_or_create_user(requester, 'a', display_name)
|
||||||
|
|
||||||
|
self.hs.config.limit_usage_by_mau = True
|
||||||
|
|
||||||
|
with self.assertRaises(RegistrationError):
|
||||||
|
yield self.handler.get_or_create_user(requester, 'b', display_name)
|
||||||
|
|
||||||
|
store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
|
||||||
|
|
||||||
|
self._macaroon_mock_generator("another_secret")
|
||||||
|
|
||||||
|
# Ensure does not throw exception
|
||||||
|
yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
|
||||||
|
|
||||||
|
self._macaroon_mock_generator("another another secret")
|
||||||
|
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||||
|
|
||||||
|
with self.assertRaises(RegistrationError):
|
||||||
|
yield self.handler.register(localpart=local_part)
|
||||||
|
|
||||||
|
self._macaroon_mock_generator("another another secret")
|
||||||
|
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||||
|
|
||||||
|
with self.assertRaises(RegistrationError):
|
||||||
|
yield self.handler.register_saml2(local_part)
|
||||||
|
|
||||||
|
def _macaroon_mock_generator(self, secret):
|
||||||
|
"""
|
||||||
|
Reset macaroon generator in the case where the test creates multiple users
|
||||||
|
"""
|
||||||
|
macaroon_generator = Mock(
|
||||||
|
generate_access_token=Mock(return_value=secret))
|
||||||
|
self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
|
||||||
|
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||||
|
self.handler = self.hs.get_handlers().registration_handler
|
||||||
|
@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"):
|
|||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"pdu_failures": [],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
65
tests/storage/test__init__.py
Normal file
65
tests/storage/test__init__.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import tests.utils
|
||||||
|
|
||||||
|
|
||||||
|
class InitTestCase(tests.unittest.TestCase):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(InitTestCase, self).__init__(*args, **kwargs)
|
||||||
|
self.store = None # type: synapse.storage.DataStore
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def setUp(self):
|
||||||
|
hs = yield tests.utils.setup_test_homeserver()
|
||||||
|
|
||||||
|
hs.config.max_mau_value = 50
|
||||||
|
hs.config.limit_usage_by_mau = True
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_count_monthly_users(self):
|
||||||
|
count = yield self.store.count_monthly_users()
|
||||||
|
self.assertEqual(0, count)
|
||||||
|
|
||||||
|
yield self._insert_user_ips("@user:server1")
|
||||||
|
yield self._insert_user_ips("@user:server2")
|
||||||
|
|
||||||
|
count = yield self.store.count_monthly_users()
|
||||||
|
self.assertEqual(2, count)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _insert_user_ips(self, user):
|
||||||
|
"""
|
||||||
|
Helper function to populate user_ips without using batch insertion infra
|
||||||
|
args:
|
||||||
|
user (str): specify username i.e. @user:server.com
|
||||||
|
"""
|
||||||
|
yield self.store._simple_upsert(
|
||||||
|
table="user_ips",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user,
|
||||||
|
"access_token": "access_token",
|
||||||
|
"ip": "ip",
|
||||||
|
"user_agent": "user_agent",
|
||||||
|
"device_id": "device_id",
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"last_seen": self.clock.time_msec(),
|
||||||
|
}
|
||||||
|
)
|
@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
|
|||||||
r = yield obj.fn(2, 3)
|
r = yield obj.fn(2, 3)
|
||||||
self.assertEqual(r, 'chips')
|
self.assertEqual(r, 'chips')
|
||||||
obj.mock.assert_not_called()
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cache(self):
|
||||||
|
class Cls(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1, arg2):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
|
||||||
|
def list_fn(self, args1, arg2):
|
||||||
|
assert (
|
||||||
|
logcontext.LoggingContext.current_context().request == "c1"
|
||||||
|
)
|
||||||
|
# we want this to behave like an asynchronous function
|
||||||
|
yield run_on_reactor()
|
||||||
|
assert (
|
||||||
|
logcontext.LoggingContext.current_context().request == "c1"
|
||||||
|
)
|
||||||
|
defer.returnValue(self.mock(args1, arg2))
|
||||||
|
|
||||||
|
with logcontext.LoggingContext() as c1:
|
||||||
|
c1.request = "c1"
|
||||||
|
obj = Cls()
|
||||||
|
obj.mock.return_value = {10: 'fish', 20: 'chips'}
|
||||||
|
d1 = obj.list_fn([10, 20], 2)
|
||||||
|
self.assertEqual(
|
||||||
|
logcontext.LoggingContext.current_context(),
|
||||||
|
logcontext.LoggingContext.sentinel,
|
||||||
|
)
|
||||||
|
r = yield d1
|
||||||
|
self.assertEqual(
|
||||||
|
logcontext.LoggingContext.current_context(),
|
||||||
|
c1
|
||||||
|
)
|
||||||
|
obj.mock.assert_called_once_with([10, 20], 2)
|
||||||
|
self.assertEqual(r, {10: 'fish', 20: 'chips'})
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = {30: 'peas'}
|
||||||
|
r = yield obj.list_fn([20, 30], 2)
|
||||||
|
obj.mock.assert_called_once_with([30], 2)
|
||||||
|
self.assertEqual(r, {20: 'chips', 30: 'peas'})
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# all the values should now be cached
|
||||||
|
r = yield obj.fn(10, 2)
|
||||||
|
self.assertEqual(r, 'fish')
|
||||||
|
r = yield obj.fn(20, 2)
|
||||||
|
self.assertEqual(r, 'chips')
|
||||||
|
r = yield obj.fn(30, 2)
|
||||||
|
self.assertEqual(r, 'peas')
|
||||||
|
r = yield obj.list_fn([10, 20, 30], 2)
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_invalidate(self):
|
||||||
|
"""Make sure that invalidation callbacks are called."""
|
||||||
|
class Cls(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@descriptors.cached()
|
||||||
|
def fn(self, arg1, arg2):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
|
||||||
|
def list_fn(self, args1, arg2):
|
||||||
|
# we want this to behave like an asynchronous function
|
||||||
|
yield run_on_reactor()
|
||||||
|
defer.returnValue(self.mock(args1, arg2))
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
invalidate0 = mock.Mock()
|
||||||
|
invalidate1 = mock.Mock()
|
||||||
|
|
||||||
|
# cache miss
|
||||||
|
obj.mock.return_value = {10: 'fish', 20: 'chips'}
|
||||||
|
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
|
||||||
|
obj.mock.assert_called_once_with([10, 20], 2)
|
||||||
|
self.assertEqual(r1, {10: 'fish', 20: 'chips'})
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# cache hit
|
||||||
|
r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
self.assertEqual(r2, {10: 'fish', 20: 'chips'})
|
||||||
|
|
||||||
|
invalidate0.assert_not_called()
|
||||||
|
invalidate1.assert_not_called()
|
||||||
|
|
||||||
|
# now if we invalidate the keys, both invalidations should get called
|
||||||
|
obj.fn.invalidate((10, 2))
|
||||||
|
invalidate0.assert_called_once()
|
||||||
|
invalidate1.assert_called_once()
|
||||||
|
@ -193,7 +193,7 @@ class MockHttpResource(HttpServer):
|
|||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
|
||||||
def trigger_get(self, path):
|
def trigger_get(self, path):
|
||||||
return self.trigger("GET", path, None)
|
return self.trigger(b"GET", path, None)
|
||||||
|
|
||||||
@patch('twisted.web.http.Request')
|
@patch('twisted.web.http.Request')
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -227,7 +227,7 @@ class MockHttpResource(HttpServer):
|
|||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
if federation_auth:
|
if federation_auth:
|
||||||
headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="]
|
headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
|
||||||
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
|
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
|
||||||
|
|
||||||
# return the right path if the event requires it
|
# return the right path if the event requires it
|
||||||
@ -241,6 +241,9 @@ class MockHttpResource(HttpServer):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if isinstance(path, bytes):
|
||||||
|
path = path.decode('utf8')
|
||||||
|
|
||||||
for (method, pattern, func) in self.callbacks:
|
for (method, pattern, func) in self.callbacks:
|
||||||
if http_method != method:
|
if http_method != method:
|
||||||
continue
|
continue
|
||||||
@ -249,7 +252,7 @@ class MockHttpResource(HttpServer):
|
|||||||
if matcher:
|
if matcher:
|
||||||
try:
|
try:
|
||||||
args = [
|
args = [
|
||||||
urlparse.unquote(u).decode("UTF-8")
|
urlparse.unquote(u)
|
||||||
for u in matcher.groups()
|
for u in matcher.groups()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user