mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-12-27 04:59:23 -05:00
Merge branch 'develop' into matthew/gin_work_mem
This commit is contained in:
commit
bb9f0f3cdb
53
CHANGES.rst
53
CHANGES.rst
@ -1,3 +1,56 @@
|
|||||||
|
Unreleased
|
||||||
|
==========
|
||||||
|
|
||||||
|
synctl no longer starts the main synapse when using ``-a`` option with workers.
|
||||||
|
A new worker file should be added with ``worker_app: synapse.app.homeserver``.
|
||||||
|
|
||||||
|
This release also begins the process of renaming a number of the metrics
|
||||||
|
reported to prometheus. See `docs/metrics-howto.rst <docs/metrics-howto.rst#block-and-response-metrics-renamed-for-0-27-0>`_.
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.26.0 (2018-01-05)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
No changes since v0.26.0-rc1
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.26.0-rc1 (2017-12-13)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Add ability for ASes to publicise groups for their users (PR #2686)
|
||||||
|
* Add all local users to the user_directory and optionally search them (PR
|
||||||
|
#2723)
|
||||||
|
* Add support for custom login types for validating users (PR #2729)
|
||||||
|
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Update example Prometheus config to new format (PR #2648) Thanks to
|
||||||
|
@krombel!
|
||||||
|
* Rename redact_content option to include_content in Push API (PR #2650)
|
||||||
|
* Declare support for r0.3.0 (PR #2677)
|
||||||
|
* Improve upserts (PR #2684, #2688, #2689, #2713)
|
||||||
|
* Improve documentation of workers (PR #2700)
|
||||||
|
* Improve tracebacks on exceptions (PR #2705)
|
||||||
|
* Allow guest access to group APIs for reading (PR #2715)
|
||||||
|
* Support for posting content in federation_client script (PR #2716)
|
||||||
|
* Delete devices and pushers on logouts etc (PR #2722)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix database port script (PR #2673)
|
||||||
|
* Fix internal server error on login with ldap_auth_provider (PR #2678) Thanks
|
||||||
|
to @jkolo!
|
||||||
|
* Fix error on sqlite 3.7 (PR #2697)
|
||||||
|
* Fix OPTIONS on preview_url (PR #2707)
|
||||||
|
* Fix error handling on dns lookup (PR #2711)
|
||||||
|
* Fix wrong avatars when inviting multiple users when creating room (PR #2717)
|
||||||
|
* Fix 500 when joining matrix-dev (PR #2719)
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.25.1 (2017-11-17)
|
Changes in synapse v0.25.1 (2017-11-17)
|
||||||
=======================================
|
=======================================
|
||||||
|
|
||||||
|
@ -632,6 +632,11 @@ largest boxes pause for thought.)
|
|||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
---------------
|
---------------
|
||||||
|
|
||||||
|
You can use the federation tester to check if your homeserver is all set:
|
||||||
|
``https://matrix.org/federationtester/api/report?server_name=<your_server_name>``
|
||||||
|
If any of the attributes under "checks" is false, federation won't work.
|
||||||
|
|
||||||
The typical failure mode with federation is that when you try to join a room,
|
The typical failure mode with federation is that when you try to join a room,
|
||||||
it is rejected with "401: Unauthorized". Generally this means that other
|
it is rejected with "401: Unauthorized". Generally this means that other
|
||||||
servers in the room couldn't access yours. (Joining a room over federation is a
|
servers in the room couldn't access yours. (Joining a room over federation is a
|
||||||
|
@ -33,6 +33,53 @@ How to monitor Synapse metrics using Prometheus
|
|||||||
|
|
||||||
Restart prometheus.
|
Restart prometheus.
|
||||||
|
|
||||||
|
|
||||||
|
Block and response metrics renamed for 0.27.0
|
||||||
|
---------------------------------------------
|
||||||
|
|
||||||
|
Synapse 0.27.0 begins the process of rationalising the duplicate ``*:count``
|
||||||
|
metrics reported for the resource tracking for code blocks and HTTP requests.
|
||||||
|
|
||||||
|
At the same time, the corresponding ``*:total`` metrics are being renamed, as
|
||||||
|
the ``:total`` suffix no longer makes sense in the absence of a corresponding
|
||||||
|
``:count`` metric.
|
||||||
|
|
||||||
|
To enable a graceful migration path, this release just adds new names for the
|
||||||
|
metrics being renamed. A future release will remove the old ones.
|
||||||
|
|
||||||
|
The following table shows the new metrics, and the old metrics which they are
|
||||||
|
replacing.
|
||||||
|
|
||||||
|
==================================================== ===================================================
|
||||||
|
New name Old name
|
||||||
|
==================================================== ===================================================
|
||||||
|
synapse_util_metrics_block_count synapse_util_metrics_block_timer:count
|
||||||
|
synapse_util_metrics_block_count synapse_util_metrics_block_ru_utime:count
|
||||||
|
synapse_util_metrics_block_count synapse_util_metrics_block_ru_stime:count
|
||||||
|
synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_count:count
|
||||||
|
synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_duration:count
|
||||||
|
|
||||||
|
synapse_util_metrics_block_time_seconds synapse_util_metrics_block_timer:total
|
||||||
|
synapse_util_metrics_block_ru_utime_seconds synapse_util_metrics_block_ru_utime:total
|
||||||
|
synapse_util_metrics_block_ru_stime_seconds synapse_util_metrics_block_ru_stime:total
|
||||||
|
synapse_util_metrics_block_db_txn_count synapse_util_metrics_block_db_txn_count:total
|
||||||
|
synapse_util_metrics_block_db_txn_duration_seconds synapse_util_metrics_block_db_txn_duration:total
|
||||||
|
|
||||||
|
synapse_http_server_response_count synapse_http_server_requests
|
||||||
|
synapse_http_server_response_count synapse_http_server_response_time:count
|
||||||
|
synapse_http_server_response_count synapse_http_server_response_ru_utime:count
|
||||||
|
synapse_http_server_response_count synapse_http_server_response_ru_stime:count
|
||||||
|
synapse_http_server_response_count synapse_http_server_response_db_txn_count:count
|
||||||
|
synapse_http_server_response_count synapse_http_server_response_db_txn_duration:count
|
||||||
|
|
||||||
|
synapse_http_server_response_time_seconds synapse_http_server_response_time:total
|
||||||
|
synapse_http_server_response_ru_utime_seconds synapse_http_server_response_ru_utime:total
|
||||||
|
synapse_http_server_response_ru_stime_seconds synapse_http_server_response_ru_stime:total
|
||||||
|
synapse_http_server_response_db_txn_count synapse_http_server_response_db_txn_count:total
|
||||||
|
synapse_http_server_response_db_txn_duration_seconds synapse_http_server_response_db_txn_duration:total
|
||||||
|
==================================================== ===================================================
|
||||||
|
|
||||||
|
|
||||||
Standard Metric Names
|
Standard Metric Names
|
||||||
---------------------
|
---------------------
|
||||||
|
|
||||||
@ -42,7 +89,7 @@ have been changed to seconds, from miliseconds.
|
|||||||
|
|
||||||
================================== =============================
|
================================== =============================
|
||||||
New name Old name
|
New name Old name
|
||||||
---------------------------------- -----------------------------
|
================================== =============================
|
||||||
process_cpu_user_seconds_total process_resource_utime / 1000
|
process_cpu_user_seconds_total process_resource_utime / 1000
|
||||||
process_cpu_system_seconds_total process_resource_stime / 1000
|
process_cpu_system_seconds_total process_resource_stime / 1000
|
||||||
process_open_fds (no 'type' label) process_fds
|
process_open_fds (no 'type' label) process_fds
|
||||||
@ -52,7 +99,7 @@ The python-specific counts of garbage collector performance have been renamed.
|
|||||||
|
|
||||||
=========================== ======================
|
=========================== ======================
|
||||||
New name Old name
|
New name Old name
|
||||||
--------------------------- ----------------------
|
=========================== ======================
|
||||||
python_gc_time reactor_gc_time
|
python_gc_time reactor_gc_time
|
||||||
python_gc_unreachable_total reactor_gc_unreachable
|
python_gc_unreachable_total reactor_gc_unreachable
|
||||||
python_gc_counts reactor_gc_counts
|
python_gc_counts reactor_gc_counts
|
||||||
@ -62,7 +109,7 @@ The twisted-specific reactor metrics have been renamed.
|
|||||||
|
|
||||||
==================================== =====================
|
==================================== =====================
|
||||||
New name Old name
|
New name Old name
|
||||||
------------------------------------ ---------------------
|
==================================== =====================
|
||||||
python_twisted_reactor_pending_calls reactor_pending_calls
|
python_twisted_reactor_pending_calls reactor_pending_calls
|
||||||
python_twisted_reactor_tick_time reactor_tick_time
|
python_twisted_reactor_tick_time reactor_tick_time
|
||||||
==================================== =====================
|
==================================== =====================
|
||||||
|
133
scripts/move_remote_media_to_new_store.py
Executable file
133
scripts/move_remote_media_to_new_store.py
Executable file
@ -0,0 +1,133 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Moves a list of remote media from one media store to another.
|
||||||
|
|
||||||
|
The input should be a list of media files to be moved, one per line. Each line
|
||||||
|
should be formatted::
|
||||||
|
|
||||||
|
<origin server>|<file id>
|
||||||
|
|
||||||
|
This can be extracted from postgres with::
|
||||||
|
|
||||||
|
psql --tuples-only -A -c "select media_origin, filesystem_id from
|
||||||
|
matrix.remote_media_cache where ..."
|
||||||
|
|
||||||
|
To use, pipe the above into::
|
||||||
|
|
||||||
|
PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
from synapse.rest.media.v1.filepath import MediaFilePaths
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def main(src_repo, dest_repo):
|
||||||
|
src_paths = MediaFilePaths(src_repo)
|
||||||
|
dest_paths = MediaFilePaths(dest_repo)
|
||||||
|
for line in sys.stdin:
|
||||||
|
line = line.strip()
|
||||||
|
parts = line.split('|')
|
||||||
|
if len(parts) != 2:
|
||||||
|
print("Unable to parse input line %s" % line, file=sys.stderr)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
move_media(parts[0], parts[1], src_paths, dest_paths)
|
||||||
|
|
||||||
|
|
||||||
|
def move_media(origin_server, file_id, src_paths, dest_paths):
|
||||||
|
"""Move the given file, and any thumbnails, to the dest repo
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin_server (str):
|
||||||
|
file_id (str):
|
||||||
|
src_paths (MediaFilePaths):
|
||||||
|
dest_paths (MediaFilePaths):
|
||||||
|
"""
|
||||||
|
logger.info("%s/%s", origin_server, file_id)
|
||||||
|
|
||||||
|
# check that the original exists
|
||||||
|
original_file = src_paths.remote_media_filepath(origin_server, file_id)
|
||||||
|
if not os.path.exists(original_file):
|
||||||
|
logger.warn(
|
||||||
|
"Original for %s/%s (%s) does not exist",
|
||||||
|
origin_server, file_id, original_file,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mkdir_and_move(
|
||||||
|
original_file,
|
||||||
|
dest_paths.remote_media_filepath(origin_server, file_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# now look for thumbnails
|
||||||
|
original_thumb_dir = src_paths.remote_media_thumbnail_dir(
|
||||||
|
origin_server, file_id,
|
||||||
|
)
|
||||||
|
if not os.path.exists(original_thumb_dir):
|
||||||
|
return
|
||||||
|
|
||||||
|
mkdir_and_move(
|
||||||
|
original_thumb_dir,
|
||||||
|
dest_paths.remote_media_thumbnail_dir(origin_server, file_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir_and_move(original_file, dest_file):
|
||||||
|
dirname = os.path.dirname(dest_file)
|
||||||
|
if not os.path.exists(dirname):
|
||||||
|
logger.debug("mkdir %s", dirname)
|
||||||
|
os.makedirs(dirname)
|
||||||
|
logger.debug("mv %s %s", original_file, dest_file)
|
||||||
|
shutil.move(original_file, dest_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=__doc__,
|
||||||
|
formatter_class = argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-v", action='store_true', help='enable debug logging')
|
||||||
|
parser.add_argument(
|
||||||
|
"src_repo",
|
||||||
|
help="Path to source content repo",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"dest_repo",
|
||||||
|
help="Path to source content repo",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logging_config = {
|
||||||
|
"level": logging.DEBUG if args.v else logging.INFO,
|
||||||
|
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
|
||||||
|
}
|
||||||
|
logging.basicConfig(**logging_config)
|
||||||
|
|
||||||
|
main(args.src_repo, args.dest_repo)
|
@ -16,4 +16,4 @@
|
|||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.25.1"
|
__version__ = "0.26.0"
|
||||||
|
@ -46,6 +46,7 @@ class Codes(object):
|
|||||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||||
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
||||||
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
|
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
|
||||||
|
THREEPID_DENIED = "M_THREEPID_DENIED"
|
||||||
INVALID_USERNAME = "M_INVALID_USERNAME"
|
INVALID_USERNAME = "M_INVALID_USERNAME"
|
||||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||||
|
|
||||||
@ -140,6 +141,32 @@ class RegistrationError(SynapseError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FederationDeniedError(SynapseError):
|
||||||
|
"""An error raised when the server tries to federate with a server which
|
||||||
|
is not on its federation whitelist.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
destination (str): The destination which has been denied
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, destination):
|
||||||
|
"""Raised by federation client or server to indicate that we are
|
||||||
|
are deliberately not attempting to contact a given server because it is
|
||||||
|
not on our federation whitelist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): the domain in question
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.destination = destination
|
||||||
|
|
||||||
|
super(FederationDeniedError, self).__init__(
|
||||||
|
code=403,
|
||||||
|
msg="Federation denied with %s." % (self.destination,),
|
||||||
|
errcode=Codes.FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InteractiveAuthIncompleteError(Exception):
|
class InteractiveAuthIncompleteError(Exception):
|
||||||
"""An error raised when UI auth is not yet complete
|
"""An error raised when UI auth is not yet complete
|
||||||
|
|
||||||
|
@ -49,19 +49,6 @@ class AppserviceSlaveStore(
|
|||||||
|
|
||||||
|
|
||||||
class AppserviceServer(HomeServer):
|
class AppserviceServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
|
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
|
||||||
|
@ -64,19 +64,6 @@ class ClientReaderSlavedStore(
|
|||||||
|
|
||||||
|
|
||||||
class ClientReaderServer(HomeServer):
|
class ClientReaderServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
|
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
|
||||||
|
@ -58,19 +58,6 @@ class FederationReaderSlavedStore(
|
|||||||
|
|
||||||
|
|
||||||
class FederationReaderServer(HomeServer):
|
class FederationReaderServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
|
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
|
||||||
|
@ -76,19 +76,6 @@ class FederationSenderSlaveStore(
|
|||||||
|
|
||||||
|
|
||||||
class FederationSenderServer(HomeServer):
|
class FederationSenderServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
|
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
|
||||||
|
@ -118,19 +118,6 @@ class FrontendProxySlavedStore(
|
|||||||
|
|
||||||
|
|
||||||
class FrontendProxyServer(HomeServer):
|
class FrontendProxyServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
|
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
|
||||||
|
@ -266,19 +266,6 @@ class SynapseHomeServer(HomeServer):
|
|||||||
except IncorrectDatabaseSetup as e:
|
except IncorrectDatabaseSetup as e:
|
||||||
quit_with_error(e.message)
|
quit_with_error(e.message)
|
||||||
|
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
|
|
||||||
def setup(config_options):
|
def setup(config_options):
|
||||||
"""
|
"""
|
||||||
|
@ -60,19 +60,6 @@ class MediaRepositorySlavedStore(
|
|||||||
|
|
||||||
|
|
||||||
class MediaRepositoryServer(HomeServer):
|
class MediaRepositoryServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
|
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
|
||||||
|
@ -81,19 +81,6 @@ class PusherSlaveStore(
|
|||||||
|
|
||||||
|
|
||||||
class PusherServer(HomeServer):
|
class PusherServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
|
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
|
||||||
|
@ -246,19 +246,6 @@ class SynchrotronApplicationService(object):
|
|||||||
|
|
||||||
|
|
||||||
class SynchrotronServer(HomeServer):
|
class SynchrotronServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
|
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
|
||||||
|
@ -184,6 +184,9 @@ def main():
|
|||||||
worker_configfiles.append(worker_configfile)
|
worker_configfiles.append(worker_configfile)
|
||||||
|
|
||||||
if options.all_processes:
|
if options.all_processes:
|
||||||
|
# To start the main synapse with -a you need to add a worker file
|
||||||
|
# with worker_app == "synapse.app.homeserver"
|
||||||
|
start_stop_synapse = False
|
||||||
worker_configdir = options.all_processes
|
worker_configdir = options.all_processes
|
||||||
if not os.path.isdir(worker_configdir):
|
if not os.path.isdir(worker_configdir):
|
||||||
write(
|
write(
|
||||||
@ -200,11 +203,29 @@ def main():
|
|||||||
with open(worker_configfile) as stream:
|
with open(worker_configfile) as stream:
|
||||||
worker_config = yaml.load(stream)
|
worker_config = yaml.load(stream)
|
||||||
worker_app = worker_config["worker_app"]
|
worker_app = worker_config["worker_app"]
|
||||||
worker_pidfile = worker_config["worker_pid_file"]
|
if worker_app == "synapse.app.homeserver":
|
||||||
worker_daemonize = worker_config["worker_daemonize"]
|
# We need to special case all of this to pick up options that may
|
||||||
assert worker_daemonize, "In config %r: expected '%s' to be True" % (
|
# be set in the main config file or in this worker config file.
|
||||||
worker_configfile, "worker_daemonize")
|
worker_pidfile = (
|
||||||
worker_cache_factor = worker_config.get("synctl_cache_factor")
|
worker_config.get("pid_file")
|
||||||
|
or pidfile
|
||||||
|
)
|
||||||
|
worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
|
||||||
|
daemonize = worker_config.get("daemonize") or config.get("daemonize")
|
||||||
|
assert daemonize, "Main process must have daemonize set to true"
|
||||||
|
|
||||||
|
# The master process doesn't support using worker_* config.
|
||||||
|
for key in worker_config:
|
||||||
|
if key == "worker_app": # But we allow worker_app
|
||||||
|
continue
|
||||||
|
assert not key.startswith("worker_"), \
|
||||||
|
"Main process cannot use worker_* config"
|
||||||
|
else:
|
||||||
|
worker_pidfile = worker_config["worker_pid_file"]
|
||||||
|
worker_daemonize = worker_config["worker_daemonize"]
|
||||||
|
assert worker_daemonize, "In config %r: expected '%s' to be True" % (
|
||||||
|
worker_configfile, "worker_daemonize")
|
||||||
|
worker_cache_factor = worker_config.get("synctl_cache_factor")
|
||||||
workers.append(Worker(
|
workers.append(Worker(
|
||||||
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
|
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
|
||||||
))
|
))
|
||||||
|
@ -92,19 +92,6 @@ class UserDirectorySlaveStore(
|
|||||||
|
|
||||||
|
|
||||||
class UserDirectoryServer(HomeServer):
|
class UserDirectoryServer(HomeServer):
|
||||||
def get_db_conn(self, run_new_connection=True):
|
|
||||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
|
||||||
# not be passed to the database engine.
|
|
||||||
db_params = {
|
|
||||||
k: v for k, v in self.db_config.get("args", {}).items()
|
|
||||||
if not k.startswith("cp_")
|
|
||||||
}
|
|
||||||
db_conn = self.database_engine.module.connect(**db_params)
|
|
||||||
|
|
||||||
if run_new_connection:
|
|
||||||
self.database_engine.on_new_connection(db_conn)
|
|
||||||
return db_conn
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
logger.info("Setting up.")
|
logger.info("Setting up.")
|
||||||
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
|
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
|
||||||
|
@ -28,27 +28,27 @@ DEFAULT_LOG_CONFIG = Template("""
|
|||||||
version: 1
|
version: 1
|
||||||
|
|
||||||
formatters:
|
formatters:
|
||||||
precise:
|
precise:
|
||||||
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
|
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
|
||||||
- %(message)s'
|
%(request)s - %(message)s'
|
||||||
|
|
||||||
filters:
|
filters:
|
||||||
context:
|
context:
|
||||||
(): synapse.util.logcontext.LoggingContextFilter
|
(): synapse.util.logcontext.LoggingContextFilter
|
||||||
request: ""
|
request: ""
|
||||||
|
|
||||||
handlers:
|
handlers:
|
||||||
file:
|
file:
|
||||||
class: logging.handlers.RotatingFileHandler
|
class: logging.handlers.RotatingFileHandler
|
||||||
formatter: precise
|
formatter: precise
|
||||||
filename: ${log_file}
|
filename: ${log_file}
|
||||||
maxBytes: 104857600
|
maxBytes: 104857600
|
||||||
backupCount: 10
|
backupCount: 10
|
||||||
filters: [context]
|
filters: [context]
|
||||||
console:
|
console:
|
||||||
class: logging.StreamHandler
|
class: logging.StreamHandler
|
||||||
formatter: precise
|
formatter: precise
|
||||||
filters: [context]
|
filters: [context]
|
||||||
|
|
||||||
loggers:
|
loggers:
|
||||||
synapse:
|
synapse:
|
||||||
@ -74,17 +74,10 @@ class LoggingConfig(Config):
|
|||||||
self.log_file = self.abspath(config.get("log_file"))
|
self.log_file = self.abspath(config.get("log_file"))
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||||
log_file = self.abspath("homeserver.log")
|
|
||||||
log_config = self.abspath(
|
log_config = self.abspath(
|
||||||
os.path.join(config_dir_path, server_name + ".log.config")
|
os.path.join(config_dir_path, server_name + ".log.config")
|
||||||
)
|
)
|
||||||
return """
|
return """
|
||||||
# Logging verbosity level. Ignored if log_config is specified.
|
|
||||||
verbose: 0
|
|
||||||
|
|
||||||
# File to write logging to. Ignored if log_config is specified.
|
|
||||||
log_file: "%(log_file)s"
|
|
||||||
|
|
||||||
# A yaml python logging config file
|
# A yaml python logging config file
|
||||||
log_config: "%(log_config)s"
|
log_config: "%(log_config)s"
|
||||||
""" % locals()
|
""" % locals()
|
||||||
@ -123,9 +116,10 @@ class LoggingConfig(Config):
|
|||||||
def generate_files(self, config):
|
def generate_files(self, config):
|
||||||
log_config = config.get("log_config")
|
log_config = config.get("log_config")
|
||||||
if log_config and not os.path.exists(log_config):
|
if log_config and not os.path.exists(log_config):
|
||||||
|
log_file = self.abspath("homeserver.log")
|
||||||
with open(log_config, "wb") as log_config_file:
|
with open(log_config, "wb") as log_config_file:
|
||||||
log_config_file.write(
|
log_config_file.write(
|
||||||
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
|
DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -150,6 +144,9 @@ def setup_logging(config, use_worker_options=False):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if log_config is None:
|
if log_config is None:
|
||||||
|
# We don't have a logfile, so fall back to the 'verbosity' param from
|
||||||
|
# the config or cmdline. (Note that we generate a log config for new
|
||||||
|
# installs, so this will be an unusual case)
|
||||||
level = logging.INFO
|
level = logging.INFO
|
||||||
level_for_storage = logging.INFO
|
level_for_storage = logging.INFO
|
||||||
if config.verbosity:
|
if config.verbosity:
|
||||||
@ -157,11 +154,10 @@ def setup_logging(config, use_worker_options=False):
|
|||||||
if config.verbosity > 1:
|
if config.verbosity > 1:
|
||||||
level_for_storage = logging.DEBUG
|
level_for_storage = logging.DEBUG
|
||||||
|
|
||||||
# FIXME: we need a logging.WARN for a -q quiet option
|
|
||||||
logger = logging.getLogger('')
|
logger = logging.getLogger('')
|
||||||
logger.setLevel(level)
|
logger.setLevel(level)
|
||||||
|
|
||||||
logging.getLogger('synapse.storage').setLevel(level_for_storage)
|
logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
|
||||||
|
|
||||||
formatter = logging.Formatter(log_format)
|
formatter = logging.Formatter(log_format)
|
||||||
if log_file:
|
if log_file:
|
||||||
|
@ -31,6 +31,8 @@ class RegistrationConfig(Config):
|
|||||||
strtobool(str(config["disable_registration"]))
|
strtobool(str(config["disable_registration"]))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
||||||
|
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
||||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||||
|
|
||||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||||
@ -52,6 +54,23 @@ class RegistrationConfig(Config):
|
|||||||
# Enable registration for new users.
|
# Enable registration for new users.
|
||||||
enable_registration: False
|
enable_registration: False
|
||||||
|
|
||||||
|
# The user must provide all of the below types of 3PID when registering.
|
||||||
|
#
|
||||||
|
# registrations_require_3pid:
|
||||||
|
# - email
|
||||||
|
# - msisdn
|
||||||
|
|
||||||
|
# Mandate that users are only allowed to associate certain formats of
|
||||||
|
# 3PIDs with accounts on this server.
|
||||||
|
#
|
||||||
|
# allowed_local_3pids:
|
||||||
|
# - medium: email
|
||||||
|
# pattern: ".*@matrix\\.org"
|
||||||
|
# - medium: email
|
||||||
|
# pattern: ".*@vector\\.im"
|
||||||
|
# - medium: msisdn
|
||||||
|
# pattern: "\\+44"
|
||||||
|
|
||||||
# If set, allows registration by anyone who also has the shared
|
# If set, allows registration by anyone who also has the shared
|
||||||
# secret, even if registration is otherwise disabled.
|
# secret, even if registration is otherwise disabled.
|
||||||
registration_shared_secret: "%(registration_shared_secret)s"
|
registration_shared_secret: "%(registration_shared_secret)s"
|
||||||
|
@ -16,6 +16,8 @@
|
|||||||
from ._base import Config, ConfigError
|
from ._base import Config, ConfigError
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
from synapse.util.module_loader import load_module
|
||||||
|
|
||||||
|
|
||||||
MISSING_NETADDR = (
|
MISSING_NETADDR = (
|
||||||
"Missing netaddr library. This is required for URL preview API."
|
"Missing netaddr library. This is required for URL preview API."
|
||||||
@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
|
|||||||
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
|
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MediaStorageProviderConfig = namedtuple(
|
||||||
|
"MediaStorageProviderConfig", (
|
||||||
|
"store_local", # Whether to store newly uploaded local files
|
||||||
|
"store_remote", # Whether to store newly downloaded remote files
|
||||||
|
"store_synchronous", # Whether to wait for successful storage for local uploads
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_thumbnail_requirements(thumbnail_sizes):
|
def parse_thumbnail_requirements(thumbnail_sizes):
|
||||||
""" Takes a list of dictionaries with "width", "height", and "method" keys
|
""" Takes a list of dictionaries with "width", "height", and "method" keys
|
||||||
@ -73,16 +83,61 @@ class ContentRepositoryConfig(Config):
|
|||||||
|
|
||||||
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
||||||
|
|
||||||
self.backup_media_store_path = config.get("backup_media_store_path")
|
backup_media_store_path = config.get("backup_media_store_path")
|
||||||
if self.backup_media_store_path:
|
|
||||||
self.backup_media_store_path = self.ensure_directory(
|
|
||||||
self.backup_media_store_path
|
|
||||||
)
|
|
||||||
|
|
||||||
self.synchronous_backup_media_store = config.get(
|
synchronous_backup_media_store = config.get(
|
||||||
"synchronous_backup_media_store", False
|
"synchronous_backup_media_store", False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
storage_providers = config.get("media_storage_providers", [])
|
||||||
|
|
||||||
|
if backup_media_store_path:
|
||||||
|
if storage_providers:
|
||||||
|
raise ConfigError(
|
||||||
|
"Cannot use both 'backup_media_store_path' and 'storage_providers'"
|
||||||
|
)
|
||||||
|
|
||||||
|
storage_providers = [{
|
||||||
|
"module": "file_system",
|
||||||
|
"store_local": True,
|
||||||
|
"store_synchronous": synchronous_backup_media_store,
|
||||||
|
"store_remote": True,
|
||||||
|
"config": {
|
||||||
|
"directory": backup_media_store_path,
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
|
||||||
|
# This is a list of config that can be used to create the storage
|
||||||
|
# providers. The entries are tuples of (Class, class_config,
|
||||||
|
# MediaStorageProviderConfig), where Class is the class of the provider,
|
||||||
|
# the class_config the config to pass to it, and
|
||||||
|
# MediaStorageProviderConfig are options for StorageProviderWrapper.
|
||||||
|
#
|
||||||
|
# We don't create the storage providers here as not all workers need
|
||||||
|
# them to be started.
|
||||||
|
self.media_storage_providers = []
|
||||||
|
|
||||||
|
for provider_config in storage_providers:
|
||||||
|
# We special case the module "file_system" so as not to need to
|
||||||
|
# expose FileStorageProviderBackend
|
||||||
|
if provider_config["module"] == "file_system":
|
||||||
|
provider_config["module"] = (
|
||||||
|
"synapse.rest.media.v1.storage_provider"
|
||||||
|
".FileStorageProviderBackend"
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_class, parsed_config = load_module(provider_config)
|
||||||
|
|
||||||
|
wrapper_config = MediaStorageProviderConfig(
|
||||||
|
provider_config.get("store_local", False),
|
||||||
|
provider_config.get("store_remote", False),
|
||||||
|
provider_config.get("store_synchronous", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.media_storage_providers.append(
|
||||||
|
(provider_class, parsed_config, wrapper_config,)
|
||||||
|
)
|
||||||
|
|
||||||
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
||||||
self.dynamic_thumbnails = config["dynamic_thumbnails"]
|
self.dynamic_thumbnails = config["dynamic_thumbnails"]
|
||||||
self.thumbnail_requirements = parse_thumbnail_requirements(
|
self.thumbnail_requirements = parse_thumbnail_requirements(
|
||||||
@ -127,13 +182,19 @@ class ContentRepositoryConfig(Config):
|
|||||||
# Directory where uploaded images and attachments are stored.
|
# Directory where uploaded images and attachments are stored.
|
||||||
media_store_path: "%(media_store)s"
|
media_store_path: "%(media_store)s"
|
||||||
|
|
||||||
# A secondary directory where uploaded images and attachments are
|
# Media storage providers allow media to be stored in different
|
||||||
# stored as a backup.
|
# locations.
|
||||||
# backup_media_store_path: "%(media_store)s"
|
# media_storage_providers:
|
||||||
|
# - module: file_system
|
||||||
# Whether to wait for successful write to backup media store before
|
# # Whether to write new local files.
|
||||||
# returning successfully.
|
# store_local: false
|
||||||
# synchronous_backup_media_store: false
|
# # Whether to write new remote media
|
||||||
|
# store_remote: false
|
||||||
|
# # Whether to block upload requests waiting for write to this
|
||||||
|
# # provider to complete
|
||||||
|
# store_synchronous: false
|
||||||
|
# config:
|
||||||
|
# directory: /mnt/some/other/directory
|
||||||
|
|
||||||
# Directory where in-progress uploads are stored.
|
# Directory where in-progress uploads are stored.
|
||||||
uploads_path: "%(uploads_path)s"
|
uploads_path: "%(uploads_path)s"
|
||||||
|
@ -55,6 +55,17 @@ class ServerConfig(Config):
|
|||||||
"block_non_admin_invites", False,
|
"block_non_admin_invites", False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FIXME: federation_domain_whitelist needs sytests
|
||||||
|
self.federation_domain_whitelist = None
|
||||||
|
federation_domain_whitelist = config.get(
|
||||||
|
"federation_domain_whitelist", None
|
||||||
|
)
|
||||||
|
# turn the whitelist into a hash for speed of lookup
|
||||||
|
if federation_domain_whitelist is not None:
|
||||||
|
self.federation_domain_whitelist = {}
|
||||||
|
for domain in federation_domain_whitelist:
|
||||||
|
self.federation_domain_whitelist[domain] = True
|
||||||
|
|
||||||
if self.public_baseurl is not None:
|
if self.public_baseurl is not None:
|
||||||
if self.public_baseurl[-1] != '/':
|
if self.public_baseurl[-1] != '/':
|
||||||
self.public_baseurl += '/'
|
self.public_baseurl += '/'
|
||||||
@ -210,6 +221,17 @@ class ServerConfig(Config):
|
|||||||
# (except those sent by local server admins). The default is False.
|
# (except those sent by local server admins). The default is False.
|
||||||
# block_non_admin_invites: True
|
# block_non_admin_invites: True
|
||||||
|
|
||||||
|
# Restrict federation to the following whitelist of domains.
|
||||||
|
# N.B. we recommend also firewalling your federation listener to limit
|
||||||
|
# inbound federation traffic as early as possible, rather than relying
|
||||||
|
# purely on this application-layer restriction. If not specified, the
|
||||||
|
# default is to whitelist everything.
|
||||||
|
#
|
||||||
|
# federation_domain_whitelist:
|
||||||
|
# - lon.example.com
|
||||||
|
# - nyc.example.com
|
||||||
|
# - syd.example.com
|
||||||
|
|
||||||
# List of ports that Synapse should listen on, their purpose and their
|
# List of ports that Synapse should listen on, their purpose and their
|
||||||
# configuration.
|
# configuration.
|
||||||
listeners:
|
listeners:
|
||||||
|
@ -96,7 +96,7 @@ class TlsConfig(Config):
|
|||||||
# certificates returned by this server match one of the fingerprints.
|
# certificates returned by this server match one of the fingerprints.
|
||||||
#
|
#
|
||||||
# Synapse automatically adds the fingerprint of its own certificate
|
# Synapse automatically adds the fingerprint of its own certificate
|
||||||
# to the list. So if federation traffic is handle directly by synapse
|
# to the list. So if federation traffic is handled directly by synapse
|
||||||
# then no modification to the list is required.
|
# then no modification to the list is required.
|
||||||
#
|
#
|
||||||
# If synapse is run behind a load balancer that handles the TLS then it
|
# If synapse is run behind a load balancer that handles the TLS then it
|
||||||
|
@ -23,6 +23,11 @@ class WorkerConfig(Config):
|
|||||||
|
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
self.worker_app = config.get("worker_app")
|
self.worker_app = config.get("worker_app")
|
||||||
|
|
||||||
|
# Canonicalise worker_app so that master always has None
|
||||||
|
if self.worker_app == "synapse.app.homeserver":
|
||||||
|
self.worker_app = None
|
||||||
|
|
||||||
self.worker_listeners = config.get("worker_listeners")
|
self.worker_listeners = config.get("worker_listeners")
|
||||||
self.worker_daemonize = config.get("worker_daemonize")
|
self.worker_daemonize = config.get("worker_daemonize")
|
||||||
self.worker_pid_file = config.get("worker_pid_file")
|
self.worker_pid_file = config.get("worker_pid_file")
|
||||||
|
@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events):
|
|||||||
# TODO (erikj): Implement kicks.
|
# TODO (erikj): Implement kicks.
|
||||||
if target_banned and user_level < ban_level:
|
if target_banned and user_level < ban_level:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "You cannot unban user &s." % (target_user_id,)
|
403, "You cannot unban user %s." % (target_user_id,)
|
||||||
)
|
)
|
||||||
elif target_user_id != event.user_id:
|
elif target_user_id != event.user_id:
|
||||||
kick_level = _get_named_level(auth_events, "kick", 50)
|
kick_level = _get_named_level(auth_events, "kick", 50)
|
||||||
|
@ -16,7 +16,9 @@ import logging
|
|||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.crypto.event_signing import check_event_content_hash
|
from synapse.crypto.event_signing import check_event_content_hash
|
||||||
|
from synapse.events import FrozenEvent
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
|
from synapse.http.servlet import assert_params_in_request
|
||||||
from synapse.util import unwrapFirstError, logcontext
|
from synapse.util import unwrapFirstError, logcontext
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -169,3 +171,28 @@ class FederationBase(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return deferreds
|
return deferreds
|
||||||
|
|
||||||
|
|
||||||
|
def event_from_pdu_json(pdu_json, outlier=False):
|
||||||
|
"""Construct a FrozenEvent from an event json received over federation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdu_json (object): pdu as received over federation
|
||||||
|
outlier (bool): True to mark this event as an outlier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FrozenEvent
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SynapseError: if the pdu is missing required fields
|
||||||
|
"""
|
||||||
|
# we could probably enforce a bunch of other fields here (room_id, sender,
|
||||||
|
# origin, etc etc)
|
||||||
|
assert_params_in_request(pdu_json, ('event_id', 'type'))
|
||||||
|
event = FrozenEvent(
|
||||||
|
pdu_json
|
||||||
|
)
|
||||||
|
|
||||||
|
event.internal_metadata.outlier = outlier
|
||||||
|
|
||||||
|
return event
|
||||||
|
@ -14,28 +14,28 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from .federation_base import FederationBase
|
|
||||||
from synapse.api.constants import Membership
|
|
||||||
|
|
||||||
from synapse.api.errors import (
|
|
||||||
CodeMessageException, HttpResponseException, SynapseError,
|
|
||||||
)
|
|
||||||
from synapse.util import unwrapFirstError, logcontext
|
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
|
||||||
from synapse.util.logutils import log_function
|
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
|
||||||
from synapse.events import FrozenEvent, builder
|
|
||||||
import synapse.metrics
|
|
||||||
|
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.constants import Membership
|
||||||
|
from synapse.api.errors import (
|
||||||
|
CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
|
||||||
|
)
|
||||||
|
from synapse.events import builder
|
||||||
|
from synapse.federation.federation_base import (
|
||||||
|
FederationBase,
|
||||||
|
event_from_pdu_json,
|
||||||
|
)
|
||||||
|
import synapse.metrics
|
||||||
|
from synapse.util import logcontext, unwrapFirstError
|
||||||
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
|
from synapse.util.logutils import log_function
|
||||||
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -184,7 +184,7 @@ class FederationClient(FederationBase):
|
|||||||
logger.debug("backfill transaction_data=%s", repr(transaction_data))
|
logger.debug("backfill transaction_data=%s", repr(transaction_data))
|
||||||
|
|
||||||
pdus = [
|
pdus = [
|
||||||
self.event_from_pdu_json(p, outlier=False)
|
event_from_pdu_json(p, outlier=False)
|
||||||
for p in transaction_data["pdus"]
|
for p in transaction_data["pdus"]
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -244,7 +244,7 @@ class FederationClient(FederationBase):
|
|||||||
logger.debug("transaction_data %r", transaction_data)
|
logger.debug("transaction_data %r", transaction_data)
|
||||||
|
|
||||||
pdu_list = [
|
pdu_list = [
|
||||||
self.event_from_pdu_json(p, outlier=outlier)
|
event_from_pdu_json(p, outlier=outlier)
|
||||||
for p in transaction_data["pdus"]
|
for p in transaction_data["pdus"]
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -266,6 +266,9 @@ class FederationClient(FederationBase):
|
|||||||
except NotRetryingDestination as e:
|
except NotRetryingDestination as e:
|
||||||
logger.info(e.message)
|
logger.info(e.message)
|
||||||
continue
|
continue
|
||||||
|
except FederationDeniedError as e:
|
||||||
|
logger.info(e.message)
|
||||||
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pdu_attempts[destination] = now
|
pdu_attempts[destination] = now
|
||||||
|
|
||||||
@ -336,11 +339,11 @@ class FederationClient(FederationBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
pdus = [
|
pdus = [
|
||||||
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
|
event_from_pdu_json(p, outlier=True) for p in result["pdus"]
|
||||||
]
|
]
|
||||||
|
|
||||||
auth_chain = [
|
auth_chain = [
|
||||||
self.event_from_pdu_json(p, outlier=True)
|
event_from_pdu_json(p, outlier=True)
|
||||||
for p in result.get("auth_chain", [])
|
for p in result.get("auth_chain", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -441,7 +444,7 @@ class FederationClient(FederationBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
auth_chain = [
|
auth_chain = [
|
||||||
self.event_from_pdu_json(p, outlier=True)
|
event_from_pdu_json(p, outlier=True)
|
||||||
for p in res["auth_chain"]
|
for p in res["auth_chain"]
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -570,12 +573,12 @@ class FederationClient(FederationBase):
|
|||||||
logger.debug("Got content: %s", content)
|
logger.debug("Got content: %s", content)
|
||||||
|
|
||||||
state = [
|
state = [
|
||||||
self.event_from_pdu_json(p, outlier=True)
|
event_from_pdu_json(p, outlier=True)
|
||||||
for p in content.get("state", [])
|
for p in content.get("state", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
auth_chain = [
|
auth_chain = [
|
||||||
self.event_from_pdu_json(p, outlier=True)
|
event_from_pdu_json(p, outlier=True)
|
||||||
for p in content.get("auth_chain", [])
|
for p in content.get("auth_chain", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -650,7 +653,7 @@ class FederationClient(FederationBase):
|
|||||||
|
|
||||||
logger.debug("Got response to send_invite: %s", pdu_dict)
|
logger.debug("Got response to send_invite: %s", pdu_dict)
|
||||||
|
|
||||||
pdu = self.event_from_pdu_json(pdu_dict)
|
pdu = event_from_pdu_json(pdu_dict)
|
||||||
|
|
||||||
# Check signatures are correct.
|
# Check signatures are correct.
|
||||||
pdu = yield self._check_sigs_and_hash(pdu)
|
pdu = yield self._check_sigs_and_hash(pdu)
|
||||||
@ -740,7 +743,7 @@ class FederationClient(FederationBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
auth_chain = [
|
auth_chain = [
|
||||||
self.event_from_pdu_json(e)
|
event_from_pdu_json(e)
|
||||||
for e in content["auth_chain"]
|
for e in content["auth_chain"]
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -788,7 +791,7 @@ class FederationClient(FederationBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
events = [
|
events = [
|
||||||
self.event_from_pdu_json(e)
|
event_from_pdu_json(e)
|
||||||
for e in content.get("events", [])
|
for e in content.get("events", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -805,15 +808,6 @@ class FederationClient(FederationBase):
|
|||||||
|
|
||||||
defer.returnValue(signed_events)
|
defer.returnValue(signed_events)
|
||||||
|
|
||||||
def event_from_pdu_json(self, pdu_json, outlier=False):
|
|
||||||
event = FrozenEvent(
|
|
||||||
pdu_json
|
|
||||||
)
|
|
||||||
|
|
||||||
event.internal_metadata.outlier = outlier
|
|
||||||
|
|
||||||
return event
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def forward_third_party_invite(self, destinations, room_id, event_dict):
|
def forward_third_party_invite(self, destinations, room_id, event_dict):
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
|
@ -12,25 +12,24 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from twisted.internet import defer
|
import logging
|
||||||
|
|
||||||
from .federation_base import FederationBase
|
|
||||||
from .units import Transaction, Edu
|
|
||||||
|
|
||||||
from synapse.util import async
|
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
|
||||||
from synapse.util.logutils import log_function
|
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
|
||||||
from synapse.events import FrozenEvent
|
|
||||||
from synapse.types import get_domain_from_id
|
|
||||||
import synapse.metrics
|
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, FederationError, SynapseError
|
|
||||||
|
|
||||||
from synapse.crypto.event_signing import compute_event_signature
|
|
||||||
|
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import logging
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import AuthError, FederationError, SynapseError
|
||||||
|
from synapse.crypto.event_signing import compute_event_signature
|
||||||
|
from synapse.federation.federation_base import (
|
||||||
|
FederationBase,
|
||||||
|
event_from_pdu_json,
|
||||||
|
)
|
||||||
|
from synapse.federation.units import Edu, Transaction
|
||||||
|
import synapse.metrics
|
||||||
|
from synapse.types import get_domain_from_id
|
||||||
|
from synapse.util import async
|
||||||
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
|
from synapse.util.logutils import log_function
|
||||||
|
|
||||||
# when processing incoming transactions, we try to handle multiple rooms in
|
# when processing incoming transactions, we try to handle multiple rooms in
|
||||||
# parallel, up to this limit.
|
# parallel, up to this limit.
|
||||||
@ -172,7 +171,7 @@ class FederationServer(FederationBase):
|
|||||||
p["age_ts"] = request_time - int(p["age"])
|
p["age_ts"] = request_time - int(p["age"])
|
||||||
del p["age"]
|
del p["age"]
|
||||||
|
|
||||||
event = self.event_from_pdu_json(p)
|
event = event_from_pdu_json(p)
|
||||||
room_id = event.room_id
|
room_id = event.room_id
|
||||||
pdus_by_room.setdefault(room_id, []).append(event)
|
pdus_by_room.setdefault(room_id, []).append(event)
|
||||||
|
|
||||||
@ -346,7 +345,7 @@ class FederationServer(FederationBase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_invite_request(self, origin, content):
|
def on_invite_request(self, origin, content):
|
||||||
pdu = self.event_from_pdu_json(content)
|
pdu = event_from_pdu_json(content)
|
||||||
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
|
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
|
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
|
||||||
@ -354,7 +353,7 @@ class FederationServer(FederationBase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_send_join_request(self, origin, content):
|
def on_send_join_request(self, origin, content):
|
||||||
logger.debug("on_send_join_request: content: %s", content)
|
logger.debug("on_send_join_request: content: %s", content)
|
||||||
pdu = self.event_from_pdu_json(content)
|
pdu = event_from_pdu_json(content)
|
||||||
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
|
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
|
||||||
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
|
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
@ -374,7 +373,7 @@ class FederationServer(FederationBase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_send_leave_request(self, origin, content):
|
def on_send_leave_request(self, origin, content):
|
||||||
logger.debug("on_send_leave_request: content: %s", content)
|
logger.debug("on_send_leave_request: content: %s", content)
|
||||||
pdu = self.event_from_pdu_json(content)
|
pdu = event_from_pdu_json(content)
|
||||||
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
|
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
|
||||||
yield self.handler.on_send_leave_request(origin, pdu)
|
yield self.handler.on_send_leave_request(origin, pdu)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
@ -411,7 +410,7 @@ class FederationServer(FederationBase):
|
|||||||
"""
|
"""
|
||||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||||
auth_chain = [
|
auth_chain = [
|
||||||
self.event_from_pdu_json(e)
|
event_from_pdu_json(e)
|
||||||
for e in content["auth_chain"]
|
for e in content["auth_chain"]
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -586,15 +585,6 @@ class FederationServer(FederationBase):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "<ReplicationLayer(%s)>" % self.server_name
|
return "<ReplicationLayer(%s)>" % self.server_name
|
||||||
|
|
||||||
def event_from_pdu_json(self, pdu_json, outlier=False):
|
|
||||||
event = FrozenEvent(
|
|
||||||
pdu_json
|
|
||||||
)
|
|
||||||
|
|
||||||
event.internal_metadata.outlier = outlier
|
|
||||||
|
|
||||||
return event
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def exchange_third_party_invite(
|
def exchange_third_party_invite(
|
||||||
self,
|
self,
|
||||||
|
@ -19,7 +19,7 @@ from twisted.internet import defer
|
|||||||
from .persistence import TransactionActions
|
from .persistence import TransactionActions
|
||||||
from .units import Transaction, Edu
|
from .units import Transaction, Edu
|
||||||
|
|
||||||
from synapse.api.errors import HttpResponseException
|
from synapse.api.errors import HttpResponseException, FederationDeniedError
|
||||||
from synapse.util import logcontext, PreserveLoggingContext
|
from synapse.util import logcontext, PreserveLoggingContext
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||||
@ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus")
|
|||||||
|
|
||||||
sent_transactions_counter = client_metrics.register_counter("sent_transactions")
|
sent_transactions_counter = client_metrics.register_counter("sent_transactions")
|
||||||
|
|
||||||
|
events_processed_counter = client_metrics.register_counter("events_processed")
|
||||||
|
|
||||||
|
|
||||||
class TransactionQueue(object):
|
class TransactionQueue(object):
|
||||||
"""This class makes sure we only have one transaction in flight at
|
"""This class makes sure we only have one transaction in flight at
|
||||||
@ -205,6 +207,8 @@ class TransactionQueue(object):
|
|||||||
|
|
||||||
self._send_pdu(event, destinations)
|
self._send_pdu(event, destinations)
|
||||||
|
|
||||||
|
events_processed_counter.inc_by(len(events))
|
||||||
|
|
||||||
yield self.store.update_federation_out_pos(
|
yield self.store.update_federation_out_pos(
|
||||||
"events", next_token
|
"events", next_token
|
||||||
)
|
)
|
||||||
@ -486,6 +490,8 @@ class TransactionQueue(object):
|
|||||||
(e.retry_last_ts + e.retry_interval) / 1000.0
|
(e.retry_last_ts + e.retry_interval) / 1000.0
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
except FederationDeniedError as e:
|
||||||
|
logger.info(e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"TX [%s] Failed to send transaction: %s",
|
"TX [%s] Failed to send transaction: %s",
|
||||||
|
@ -212,6 +212,9 @@ class TransportLayerClient(object):
|
|||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if the remote destination
|
||||||
|
is not in our federation whitelist
|
||||||
"""
|
"""
|
||||||
valid_memberships = {Membership.JOIN, Membership.LEAVE}
|
valid_memberships = {Membership.JOIN, Membership.LEAVE}
|
||||||
if membership not in valid_memberships:
|
if membership not in valid_memberships:
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError, FederationDeniedError
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
||||||
@ -81,6 +81,7 @@ class Authenticator(object):
|
|||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||||
|
|
||||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -92,6 +93,12 @@ class Authenticator(object):
|
|||||||
"signatures": {},
|
"signatures": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.federation_domain_whitelist is not None and
|
||||||
|
self.server_name not in self.federation_domain_whitelist
|
||||||
|
):
|
||||||
|
raise FederationDeniedError(self.server_name)
|
||||||
|
|
||||||
if content is not None:
|
if content is not None:
|
||||||
json_request["content"] = content
|
json_request["content"] = content
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
@ -23,6 +24,10 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
|
events_processed_counter = metrics.register_counter("events_processed")
|
||||||
|
|
||||||
|
|
||||||
def log_failure(failure):
|
def log_failure(failure):
|
||||||
logger.error(
|
logger.error(
|
||||||
@ -103,6 +108,8 @@ class ApplicationServicesHandler(object):
|
|||||||
service, event
|
service, event
|
||||||
)
|
)
|
||||||
|
|
||||||
|
events_processed_counter.inc_by(len(events))
|
||||||
|
|
||||||
yield self.store.set_appservice_last_pos(upper_bound)
|
yield self.store.set_appservice_last_pos(upper_bound)
|
||||||
finally:
|
finally:
|
||||||
self.is_processing = False
|
self.is_processing = False
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer, threads
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
@ -25,6 +25,7 @@ from synapse.module_api import ModuleApi
|
|||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
|
|
||||||
from twisted.web.client import PartialDownloadError
|
from twisted.web.client import PartialDownloadError
|
||||||
|
|
||||||
@ -714,7 +715,7 @@ class AuthHandler(BaseHandler):
|
|||||||
if not lookupres:
|
if not lookupres:
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
(user_id, password_hash) = lookupres
|
(user_id, password_hash) = lookupres
|
||||||
result = self.validate_hash(password, password_hash)
|
result = yield self.validate_hash(password, password_hash)
|
||||||
if not result:
|
if not result:
|
||||||
logger.warn("Failed password login for user %s", user_id)
|
logger.warn("Failed password login for user %s", user_id)
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
@ -842,10 +843,13 @@ class AuthHandler(BaseHandler):
|
|||||||
password (str): Password to hash.
|
password (str): Password to hash.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Hashed password (str).
|
Deferred(str): Hashed password.
|
||||||
"""
|
"""
|
||||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
def _do_hash():
|
||||||
bcrypt.gensalt(self.bcrypt_rounds))
|
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||||
|
bcrypt.gensalt(self.bcrypt_rounds))
|
||||||
|
|
||||||
|
return make_deferred_yieldable(threads.deferToThread(_do_hash))
|
||||||
|
|
||||||
def validate_hash(self, password, stored_hash):
|
def validate_hash(self, password, stored_hash):
|
||||||
"""Validates that self.hash(password) == stored_hash.
|
"""Validates that self.hash(password) == stored_hash.
|
||||||
@ -855,13 +859,17 @@ class AuthHandler(BaseHandler):
|
|||||||
stored_hash (str): Expected hash value.
|
stored_hash (str): Expected hash value.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Whether self.hash(password) == stored_hash (bool).
|
Deferred(bool): Whether self.hash(password) == stored_hash.
|
||||||
"""
|
"""
|
||||||
if stored_hash:
|
|
||||||
|
def _do_validate_hash():
|
||||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||||
stored_hash.encode('utf8')) == stored_hash
|
stored_hash.encode('utf8')) == stored_hash
|
||||||
|
|
||||||
|
if stored_hash:
|
||||||
|
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
|
||||||
else:
|
else:
|
||||||
return False
|
return defer.succeed(False)
|
||||||
|
|
||||||
|
|
||||||
class MacaroonGeneartor(object):
|
class MacaroonGeneartor(object):
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse.api import errors
|
from synapse.api import errors
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.api.errors import FederationDeniedError
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
@ -513,6 +514,9 @@ class DeviceListEduUpdater(object):
|
|||||||
# This makes it more likely that the device lists will
|
# This makes it more likely that the device lists will
|
||||||
# eventually become consistent.
|
# eventually become consistent.
|
||||||
return
|
return
|
||||||
|
except FederationDeniedError as e:
|
||||||
|
logger.info(e)
|
||||||
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
# TODO: Remember that we are now out of sync and try again
|
# TODO: Remember that we are now out of sync and try again
|
||||||
# later
|
# later
|
||||||
|
@ -17,7 +17,8 @@ import logging
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.types import get_domain_from_id, UserID
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +34,7 @@ class DeviceMessageHandler(object):
|
|||||||
"""
|
"""
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine = hs.is_mine
|
||||||
self.federation = hs.get_federation_sender()
|
self.federation = hs.get_federation_sender()
|
||||||
|
|
||||||
hs.get_replication_layer().register_edu_handler(
|
hs.get_replication_layer().register_edu_handler(
|
||||||
@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
|
|||||||
message_type = content["type"]
|
message_type = content["type"]
|
||||||
message_id = content["message_id"]
|
message_id = content["message_id"]
|
||||||
for user_id, by_device in content["messages"].items():
|
for user_id, by_device in content["messages"].items():
|
||||||
|
# we use UserID.from_string to catch invalid user ids
|
||||||
|
if not self.is_mine(UserID.from_string(user_id)):
|
||||||
|
logger.warning("Request for keys for non-local user %s",
|
||||||
|
user_id)
|
||||||
|
raise SynapseError(400, "Not a user here")
|
||||||
|
|
||||||
messages_by_device = {
|
messages_by_device = {
|
||||||
device_id: {
|
device_id: {
|
||||||
"content": message_content,
|
"content": message_content,
|
||||||
@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
|
|||||||
local_messages = {}
|
local_messages = {}
|
||||||
remote_messages = {}
|
remote_messages = {}
|
||||||
for user_id, by_device in messages.items():
|
for user_id, by_device in messages.items():
|
||||||
if self.is_mine_id(user_id):
|
# we use UserID.from_string to catch invalid user ids
|
||||||
|
if self.is_mine(UserID.from_string(user_id)):
|
||||||
messages_by_device = {
|
messages_by_device = {
|
||||||
device_id: {
|
device_id: {
|
||||||
"content": message_content,
|
"content": message_content,
|
||||||
|
@ -19,8 +19,10 @@ import logging
|
|||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, CodeMessageException
|
from synapse.api.errors import (
|
||||||
from synapse.types import get_domain_from_id
|
SynapseError, CodeMessageException, FederationDeniedError,
|
||||||
|
)
|
||||||
|
from synapse.types import get_domain_from_id, UserID
|
||||||
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
@ -32,7 +34,7 @@ class E2eKeysHandler(object):
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_replication_layer()
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine = hs.is_mine
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
# doesn't really work as part of the generic query API, because the
|
# doesn't really work as part of the generic query API, because the
|
||||||
@ -70,7 +72,8 @@ class E2eKeysHandler(object):
|
|||||||
remote_queries = {}
|
remote_queries = {}
|
||||||
|
|
||||||
for user_id, device_ids in device_keys_query.items():
|
for user_id, device_ids in device_keys_query.items():
|
||||||
if self.is_mine_id(user_id):
|
# we use UserID.from_string to catch invalid user ids
|
||||||
|
if self.is_mine(UserID.from_string(user_id)):
|
||||||
local_query[user_id] = device_ids
|
local_query[user_id] = device_ids
|
||||||
else:
|
else:
|
||||||
remote_queries[user_id] = device_ids
|
remote_queries[user_id] = device_ids
|
||||||
@ -139,6 +142,10 @@ class E2eKeysHandler(object):
|
|||||||
failures[destination] = {
|
failures[destination] = {
|
||||||
"status": 503, "message": "Not ready for retry",
|
"status": 503, "message": "Not ready for retry",
|
||||||
}
|
}
|
||||||
|
except FederationDeniedError as e:
|
||||||
|
failures[destination] = {
|
||||||
|
"status": 403, "message": "Federation Denied",
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# include ConnectionRefused and other errors
|
# include ConnectionRefused and other errors
|
||||||
failures[destination] = {
|
failures[destination] = {
|
||||||
@ -170,7 +177,8 @@ class E2eKeysHandler(object):
|
|||||||
|
|
||||||
result_dict = {}
|
result_dict = {}
|
||||||
for user_id, device_ids in query.items():
|
for user_id, device_ids in query.items():
|
||||||
if not self.is_mine_id(user_id):
|
# we use UserID.from_string to catch invalid user ids
|
||||||
|
if not self.is_mine(UserID.from_string(user_id)):
|
||||||
logger.warning("Request for keys for non-local user %s",
|
logger.warning("Request for keys for non-local user %s",
|
||||||
user_id)
|
user_id)
|
||||||
raise SynapseError(400, "Not a user here")
|
raise SynapseError(400, "Not a user here")
|
||||||
@ -213,7 +221,8 @@ class E2eKeysHandler(object):
|
|||||||
remote_queries = {}
|
remote_queries = {}
|
||||||
|
|
||||||
for user_id, device_keys in query.get("one_time_keys", {}).items():
|
for user_id, device_keys in query.get("one_time_keys", {}).items():
|
||||||
if self.is_mine_id(user_id):
|
# we use UserID.from_string to catch invalid user ids
|
||||||
|
if self.is_mine(UserID.from_string(user_id)):
|
||||||
for device_id, algorithm in device_keys.items():
|
for device_id, algorithm in device_keys.items():
|
||||||
local_query.append((user_id, device_id, algorithm))
|
local_query.append((user_id, device_id, algorithm))
|
||||||
else:
|
else:
|
||||||
|
@ -22,6 +22,7 @@ from ._base import BaseHandler
|
|||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
|
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
|
||||||
|
FederationDeniedError,
|
||||||
)
|
)
|
||||||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
@ -782,6 +783,9 @@ class FederationHandler(BaseHandler):
|
|||||||
except NotRetryingDestination as e:
|
except NotRetryingDestination as e:
|
||||||
logger.info(e.message)
|
logger.info(e.message)
|
||||||
continue
|
continue
|
||||||
|
except FederationDeniedError as e:
|
||||||
|
logger.info(e)
|
||||||
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to backfill from %s because %s",
|
"Failed to backfill from %s because %s",
|
||||||
|
@ -383,11 +383,12 @@ class GroupsLocalHandler(object):
|
|||||||
|
|
||||||
defer.returnValue({"groups": result})
|
defer.returnValue({"groups": result})
|
||||||
else:
|
else:
|
||||||
result = yield self.transport_client.get_publicised_groups_for_user(
|
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
|
||||||
get_domain_from_id(user_id), user_id
|
get_domain_from_id(user_id), [user_id],
|
||||||
)
|
)
|
||||||
|
result = bulk_result.get("users", {}).get(user_id)
|
||||||
# TODO: Verify attestations
|
# TODO: Verify attestations
|
||||||
defer.returnValue(result)
|
defer.returnValue({"groups": result})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def bulk_get_publicised_groups(self, user_ids, proxy=True):
|
def bulk_get_publicised_groups(self, user_ids, proxy=True):
|
||||||
|
@ -25,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient
|
|||||||
from synapse import types
|
from synapse import types
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -131,7 +132,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
password_hash = None
|
password_hash = None
|
||||||
if password:
|
if password:
|
||||||
password_hash = self.auth_handler().hash(password)
|
password_hash = yield self.auth_handler().hash(password)
|
||||||
|
|
||||||
if localpart:
|
if localpart:
|
||||||
yield self.check_username(localpart, guest_access_token=guest_access_token)
|
yield self.check_username(localpart, guest_access_token=guest_access_token)
|
||||||
@ -293,7 +294,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
for c in threepidCreds:
|
for c in threepidCreds:
|
||||||
logger.info("validating theeepidcred sid %s on id server %s",
|
logger.info("validating threepidcred sid %s on id server %s",
|
||||||
c['sid'], c['idServer'])
|
c['sid'], c['idServer'])
|
||||||
try:
|
try:
|
||||||
identity_handler = self.hs.get_handlers().identity_handler
|
identity_handler = self.hs.get_handlers().identity_handler
|
||||||
@ -307,6 +308,11 @@ class RegistrationHandler(BaseHandler):
|
|||||||
logger.info("got threepid with medium '%s' and address '%s'",
|
logger.info("got threepid with medium '%s' and address '%s'",
|
||||||
threepid['medium'], threepid['address'])
|
threepid['medium'], threepid['address'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
|
||||||
|
raise RegistrationError(
|
||||||
|
403, "Third party identifier is not allowed"
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def bind_emails(self, user_id, threepidCreds):
|
def bind_emails(self, user_id, threepidCreds):
|
||||||
"""Links emails with a user ID and informs an identity server.
|
"""Links emails with a user ID and informs an identity server.
|
||||||
|
@ -203,7 +203,8 @@ class RoomListHandler(BaseHandler):
|
|||||||
if limit:
|
if limit:
|
||||||
step = limit + 1
|
step = limit + 1
|
||||||
else:
|
else:
|
||||||
step = len(rooms_to_scan)
|
# step cannot be zero
|
||||||
|
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
|
||||||
|
|
||||||
chunk = []
|
chunk = []
|
||||||
for i in xrange(0, len(rooms_to_scan), step):
|
for i in xrange(0, len(rooms_to_scan), step):
|
||||||
|
@ -31,7 +31,7 @@ class SetPasswordHandler(BaseHandler):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_password(self, user_id, newpassword, requester=None):
|
def set_password(self, user_id, newpassword, requester=None):
|
||||||
password_hash = self._auth_handler.hash(newpassword)
|
password_hash = yield self._auth_handler.hash(newpassword)
|
||||||
|
|
||||||
except_device_id = requester.device_id if requester else None
|
except_device_id = requester.device_id if requester else None
|
||||||
except_access_token_id = requester.access_token_id if requester else None
|
except_access_token_id = requester.access_token_id if requester else None
|
||||||
|
@ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
|
|||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
|
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
|
||||||
)
|
)
|
||||||
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||||
from synapse.util.logcontext import make_deferred_yieldable
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
from synapse.util import logcontext
|
from synapse.util import logcontext
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
@ -30,6 +31,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
|||||||
from twisted.web.client import (
|
from twisted.web.client import (
|
||||||
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
|
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
|
||||||
readBody, PartialDownloadError,
|
readBody, PartialDownloadError,
|
||||||
|
HTTPConnectionPool,
|
||||||
)
|
)
|
||||||
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
|
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
|
||||||
from twisted.web.http import PotentialDataLoss
|
from twisted.web.http import PotentialDataLoss
|
||||||
@ -64,13 +66,23 @@ class SimpleHttpClient(object):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
|
pool = HTTPConnectionPool(reactor)
|
||||||
|
|
||||||
|
# the pusher makes lots of concurrent SSL connections to sygnal, and
|
||||||
|
# tends to do so in batches, so we need to allow the pool to keep lots
|
||||||
|
# of idle connections around.
|
||||||
|
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
|
||||||
|
pool.cachedConnectionTimeout = 2 * 60
|
||||||
|
|
||||||
# The default context factory in Twisted 14.0.0 (which we require) is
|
# The default context factory in Twisted 14.0.0 (which we require) is
|
||||||
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
||||||
# 'like a browser'
|
# 'like a browser'
|
||||||
self.agent = Agent(
|
self.agent = Agent(
|
||||||
reactor,
|
reactor,
|
||||||
connectTimeout=15,
|
connectTimeout=15,
|
||||||
contextFactory=hs.get_http_client_context_factory()
|
contextFactory=hs.get_http_client_context_factory(),
|
||||||
|
pool=pool,
|
||||||
)
|
)
|
||||||
self.user_agent = hs.version_string
|
self.user_agent = hs.version_string
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
@ -357,8 +357,7 @@ def _get_hosts_for_srv_record(dns_client, host):
|
|||||||
def eb(res, record_type):
|
def eb(res, record_type):
|
||||||
if res.check(DNSNameError):
|
if res.check(DNSNameError):
|
||||||
return []
|
return []
|
||||||
logger.warn("Error looking up %s for %s: %s",
|
logger.warn("Error looking up %s for %s: %s", record_type, host, res)
|
||||||
record_type, host, res, res.value)
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
# no logcontexts here, so we can safely fire these off and gatherResults
|
# no logcontexts here, so we can safely fire these off and gatherResults
|
||||||
|
@ -27,7 +27,7 @@ import synapse.metrics
|
|||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
SynapseError, Codes, HttpResponseException,
|
SynapseError, Codes, HttpResponseException, FederationDeniedError,
|
||||||
)
|
)
|
||||||
|
|
||||||
from signedjson.sign import sign_json
|
from signedjson.sign import sign_json
|
||||||
@ -123,11 +123,22 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
Fails with ``HTTPRequestException``: if we get an HTTP response
|
Fails with ``HTTPRequestException``: if we get an HTTP response
|
||||||
code >= 300.
|
code >= 300.
|
||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
|
|
||||||
(May also fail with plenty of other Exceptions for things like DNS
|
(May also fail with plenty of other Exceptions for things like DNS
|
||||||
failures, connection failures, SSL failures.)
|
failures, connection failures, SSL failures.)
|
||||||
"""
|
"""
|
||||||
|
if (
|
||||||
|
self.hs.config.federation_domain_whitelist and
|
||||||
|
destination not in self.hs.config.federation_domain_whitelist
|
||||||
|
):
|
||||||
|
raise FederationDeniedError(destination)
|
||||||
|
|
||||||
limiter = yield synapse.util.retryutils.get_retry_limiter(
|
limiter = yield synapse.util.retryutils.get_retry_limiter(
|
||||||
destination,
|
destination,
|
||||||
self.clock,
|
self.clock,
|
||||||
@ -308,6 +319,9 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not json_data_callback:
|
if not json_data_callback:
|
||||||
@ -368,6 +382,9 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def body_callback(method, url_bytes, headers_dict):
|
def body_callback(method, url_bytes, headers_dict):
|
||||||
@ -422,6 +439,9 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
"""
|
"""
|
||||||
logger.debug("get_json args: %s", args)
|
logger.debug("get_json args: %s", args)
|
||||||
|
|
||||||
@ -475,6 +495,9 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response = yield self._request(
|
response = yield self._request(
|
||||||
@ -518,6 +541,9 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||||
to retry this server.
|
to retry this server.
|
||||||
|
|
||||||
|
Fails with ``FederationDeniedError`` if this destination
|
||||||
|
is not on our federation whitelist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
encoded_args = {}
|
encoded_args = {}
|
||||||
|
@ -42,36 +42,70 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
incoming_requests_counter = metrics.register_counter(
|
# total number of responses served, split by method/servlet/tag
|
||||||
"requests",
|
response_count = metrics.register_counter(
|
||||||
|
"response_count",
|
||||||
labels=["method", "servlet", "tag"],
|
labels=["method", "servlet", "tag"],
|
||||||
|
alternative_names=(
|
||||||
|
# the following are all deprecated aliases for the same metric
|
||||||
|
metrics.name_prefix + x for x in (
|
||||||
|
"_requests",
|
||||||
|
"_response_time:count",
|
||||||
|
"_response_ru_utime:count",
|
||||||
|
"_response_ru_stime:count",
|
||||||
|
"_response_db_txn_count:count",
|
||||||
|
"_response_db_txn_duration:count",
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
outgoing_responses_counter = metrics.register_counter(
|
outgoing_responses_counter = metrics.register_counter(
|
||||||
"responses",
|
"responses",
|
||||||
labels=["method", "code"],
|
labels=["method", "code"],
|
||||||
)
|
)
|
||||||
|
|
||||||
response_timer = metrics.register_distribution(
|
response_timer = metrics.register_counter(
|
||||||
"response_time",
|
"response_time_seconds",
|
||||||
labels=["method", "servlet", "tag"]
|
labels=["method", "servlet", "tag"],
|
||||||
|
alternative_names=(
|
||||||
|
metrics.name_prefix + "_response_time:total",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
response_ru_utime = metrics.register_distribution(
|
response_ru_utime = metrics.register_counter(
|
||||||
"response_ru_utime", labels=["method", "servlet", "tag"]
|
"response_ru_utime_seconds", labels=["method", "servlet", "tag"],
|
||||||
|
alternative_names=(
|
||||||
|
metrics.name_prefix + "_response_ru_utime:total",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
response_ru_stime = metrics.register_distribution(
|
response_ru_stime = metrics.register_counter(
|
||||||
"response_ru_stime", labels=["method", "servlet", "tag"]
|
"response_ru_stime_seconds", labels=["method", "servlet", "tag"],
|
||||||
|
alternative_names=(
|
||||||
|
metrics.name_prefix + "_response_ru_stime:total",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
response_db_txn_count = metrics.register_distribution(
|
response_db_txn_count = metrics.register_counter(
|
||||||
"response_db_txn_count", labels=["method", "servlet", "tag"]
|
"response_db_txn_count", labels=["method", "servlet", "tag"],
|
||||||
|
alternative_names=(
|
||||||
|
metrics.name_prefix + "_response_db_txn_count:total",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
response_db_txn_duration = metrics.register_distribution(
|
# seconds spent waiting for db txns, excluding scheduling time, when processing
|
||||||
"response_db_txn_duration", labels=["method", "servlet", "tag"]
|
# this request
|
||||||
|
response_db_txn_duration = metrics.register_counter(
|
||||||
|
"response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
|
||||||
|
alternative_names=(
|
||||||
|
metrics.name_prefix + "_response_db_txn_duration:total",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# seconds spent waiting for a db connection, when processing this request
|
||||||
|
response_db_sched_duration = metrics.register_counter(
|
||||||
|
"response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
|
||||||
|
)
|
||||||
|
|
||||||
_next_request_id = 0
|
_next_request_id = 0
|
||||||
|
|
||||||
@ -107,6 +141,10 @@ def wrap_request_handler(request_handler, include_metrics=False):
|
|||||||
with LoggingContext(request_id) as request_context:
|
with LoggingContext(request_id) as request_context:
|
||||||
with Measure(self.clock, "wrapped_request_handler"):
|
with Measure(self.clock, "wrapped_request_handler"):
|
||||||
request_metrics = RequestMetrics()
|
request_metrics = RequestMetrics()
|
||||||
|
# we start the request metrics timer here with an initial stab
|
||||||
|
# at the servlet name. For most requests that name will be
|
||||||
|
# JsonResource (or a subclass), and JsonResource._async_render
|
||||||
|
# will update it once it picks a servlet.
|
||||||
request_metrics.start(self.clock, name=self.__class__.__name__)
|
request_metrics.start(self.clock, name=self.__class__.__name__)
|
||||||
|
|
||||||
request_context.request = request_id
|
request_context.request = request_id
|
||||||
@ -249,12 +287,23 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
if not m:
|
if not m:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# We found a match! Trigger callback and then return the
|
# We found a match! First update the metrics object to indicate
|
||||||
# returned response. We pass both the request and any
|
# which servlet is handling the request.
|
||||||
# matched groups from the regex to the callback.
|
|
||||||
|
|
||||||
callback = path_entry.callback
|
callback = path_entry.callback
|
||||||
|
|
||||||
|
servlet_instance = getattr(callback, "__self__", None)
|
||||||
|
if servlet_instance is not None:
|
||||||
|
servlet_classname = servlet_instance.__class__.__name__
|
||||||
|
else:
|
||||||
|
servlet_classname = "%r" % callback
|
||||||
|
|
||||||
|
request_metrics.name = servlet_classname
|
||||||
|
|
||||||
|
# Now trigger the callback. If it returns a response, we send it
|
||||||
|
# here. If it throws an exception, that is handled by the wrapper
|
||||||
|
# installed by @request_handler.
|
||||||
|
|
||||||
kwargs = intern_dict({
|
kwargs = intern_dict({
|
||||||
name: urllib.unquote(value).decode("UTF-8") if value else value
|
name: urllib.unquote(value).decode("UTF-8") if value else value
|
||||||
for name, value in m.groupdict().items()
|
for name, value in m.groupdict().items()
|
||||||
@ -265,30 +314,14 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
code, response = callback_return
|
code, response = callback_return
|
||||||
self._send_response(request, code, response)
|
self._send_response(request, code, response)
|
||||||
|
|
||||||
servlet_instance = getattr(callback, "__self__", None)
|
|
||||||
if servlet_instance is not None:
|
|
||||||
servlet_classname = servlet_instance.__class__.__name__
|
|
||||||
else:
|
|
||||||
servlet_classname = "%r" % callback
|
|
||||||
|
|
||||||
request_metrics.name = servlet_classname
|
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||||
|
request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest"
|
||||||
raise UnrecognizedRequestError()
|
raise UnrecognizedRequestError()
|
||||||
|
|
||||||
def _send_response(self, request, code, response_json_object,
|
def _send_response(self, request, code, response_json_object,
|
||||||
response_code_message=None):
|
response_code_message=None):
|
||||||
# could alternatively use request.notifyFinish() and flip a flag when
|
|
||||||
# the Deferred fires, but since the flag is RIGHT THERE it seems like
|
|
||||||
# a waste.
|
|
||||||
if request._disconnected:
|
|
||||||
logger.warn(
|
|
||||||
"Not sending response to request %s, already disconnected.",
|
|
||||||
request)
|
|
||||||
return
|
|
||||||
|
|
||||||
outgoing_responses_counter.inc(request.method, str(code))
|
outgoing_responses_counter.inc(request.method, str(code))
|
||||||
|
|
||||||
# TODO: Only enable CORS for the requests that need it.
|
# TODO: Only enable CORS for the requests that need it.
|
||||||
@ -322,7 +355,7 @@ class RequestMetrics(object):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
incoming_requests_counter.inc(request.method, self.name, tag)
|
response_count.inc(request.method, self.name, tag)
|
||||||
|
|
||||||
response_timer.inc_by(
|
response_timer.inc_by(
|
||||||
clock.time_msec() - self.start, request.method,
|
clock.time_msec() - self.start, request.method,
|
||||||
@ -341,7 +374,10 @@ class RequestMetrics(object):
|
|||||||
context.db_txn_count, request.method, self.name, tag
|
context.db_txn_count, request.method, self.name, tag
|
||||||
)
|
)
|
||||||
response_db_txn_duration.inc_by(
|
response_db_txn_duration.inc_by(
|
||||||
context.db_txn_duration, request.method, self.name, tag
|
context.db_txn_duration_ms / 1000., request.method, self.name, tag
|
||||||
|
)
|
||||||
|
response_db_sched_duration.inc_by(
|
||||||
|
context.db_sched_duration_ms / 1000., request.method, self.name, tag
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -364,6 +400,15 @@ class RootRedirect(resource.Resource):
|
|||||||
def respond_with_json(request, code, json_object, send_cors=False,
|
def respond_with_json(request, code, json_object, send_cors=False,
|
||||||
response_code_message=None, pretty_print=False,
|
response_code_message=None, pretty_print=False,
|
||||||
version_string="", canonical_json=True):
|
version_string="", canonical_json=True):
|
||||||
|
# could alternatively use request.notifyFinish() and flip a flag when
|
||||||
|
# the Deferred fires, but since the flag is RIGHT THERE it seems like
|
||||||
|
# a waste.
|
||||||
|
if request._disconnected:
|
||||||
|
logger.warn(
|
||||||
|
"Not sending response to request %s, already disconnected.",
|
||||||
|
request)
|
||||||
|
return
|
||||||
|
|
||||||
if pretty_print:
|
if pretty_print:
|
||||||
json_bytes = encode_pretty_printed_json(json_object) + "\n"
|
json_bytes = encode_pretty_printed_json(json_object) + "\n"
|
||||||
else:
|
else:
|
||||||
|
@ -66,14 +66,15 @@ class SynapseRequest(Request):
|
|||||||
context = LoggingContext.current_context()
|
context = LoggingContext.current_context()
|
||||||
ru_utime, ru_stime = context.get_resource_usage()
|
ru_utime, ru_stime = context.get_resource_usage()
|
||||||
db_txn_count = context.db_txn_count
|
db_txn_count = context.db_txn_count
|
||||||
db_txn_duration = context.db_txn_duration
|
db_txn_duration_ms = context.db_txn_duration_ms
|
||||||
|
db_sched_duration_ms = context.db_sched_duration_ms
|
||||||
except Exception:
|
except Exception:
|
||||||
ru_utime, ru_stime = (0, 0)
|
ru_utime, ru_stime = (0, 0)
|
||||||
db_txn_count, db_txn_duration = (0, 0)
|
db_txn_count, db_txn_duration_ms = (0, 0)
|
||||||
|
|
||||||
self.site.access_logger.info(
|
self.site.access_logger.info(
|
||||||
"%s - %s - {%s}"
|
"%s - %s - {%s}"
|
||||||
" Processed request: %dms (%dms, %dms) (%dms/%d)"
|
" Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
|
||||||
" %sB %s \"%s %s %s\" \"%s\"",
|
" %sB %s \"%s %s %s\" \"%s\"",
|
||||||
self.getClientIP(),
|
self.getClientIP(),
|
||||||
self.site.site_tag,
|
self.site.site_tag,
|
||||||
@ -81,7 +82,8 @@ class SynapseRequest(Request):
|
|||||||
int(time.time() * 1000) - self.start_time,
|
int(time.time() * 1000) - self.start_time,
|
||||||
int(ru_utime * 1000),
|
int(ru_utime * 1000),
|
||||||
int(ru_stime * 1000),
|
int(ru_stime * 1000),
|
||||||
int(db_txn_duration * 1000),
|
db_sched_duration_ms,
|
||||||
|
db_txn_duration_ms,
|
||||||
int(db_txn_count),
|
int(db_txn_count),
|
||||||
self.sentLength,
|
self.sentLength,
|
||||||
self.code,
|
self.code,
|
||||||
|
@ -146,10 +146,15 @@ def runUntilCurrentTimer(func):
|
|||||||
num_pending += 1
|
num_pending += 1
|
||||||
|
|
||||||
num_pending += len(reactor.threadCallQueue)
|
num_pending += len(reactor.threadCallQueue)
|
||||||
|
|
||||||
start = time.time() * 1000
|
start = time.time() * 1000
|
||||||
ret = func(*args, **kwargs)
|
ret = func(*args, **kwargs)
|
||||||
end = time.time() * 1000
|
end = time.time() * 1000
|
||||||
|
|
||||||
|
# record the amount of wallclock time spent running pending calls.
|
||||||
|
# This is a proxy for the actual amount of time between reactor polls,
|
||||||
|
# since about 25% of time is actually spent running things triggered by
|
||||||
|
# I/O events, but that is harder to capture without rewriting half the
|
||||||
|
# reactor.
|
||||||
tick_time.inc_by(end - start)
|
tick_time.inc_by(end - start)
|
||||||
pending_calls_metric.inc_by(num_pending)
|
pending_calls_metric.inc_by(num_pending)
|
||||||
|
|
||||||
|
@ -15,18 +15,38 @@
|
|||||||
|
|
||||||
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# TODO(paul): I can't believe Python doesn't have one of these
|
def flatten(items):
|
||||||
def map_concat(func, items):
|
"""Flatten a list of lists
|
||||||
# flatten a list-of-lists
|
|
||||||
return list(chain.from_iterable(map(func, items)))
|
Args:
|
||||||
|
items: iterable[iterable[X]]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[X]: flattened list
|
||||||
|
"""
|
||||||
|
return list(chain.from_iterable(items))
|
||||||
|
|
||||||
|
|
||||||
class BaseMetric(object):
|
class BaseMetric(object):
|
||||||
|
"""Base class for metrics which report a single value per label set
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, name, labels=[]):
|
def __init__(self, name, labels=[], alternative_names=[]):
|
||||||
self.name = name
|
"""
|
||||||
|
Args:
|
||||||
|
name (str): principal name for this metric
|
||||||
|
labels (list(str)): names of the labels which will be reported
|
||||||
|
for this metric
|
||||||
|
alternative_names (iterable(str)): list of alternative names for
|
||||||
|
this metric. This can be useful to provide a migration path
|
||||||
|
when renaming metrics.
|
||||||
|
"""
|
||||||
|
self._names = [name] + list(alternative_names)
|
||||||
self.labels = labels # OK not to clone as we never write it
|
self.labels = labels # OK not to clone as we never write it
|
||||||
|
|
||||||
def dimension(self):
|
def dimension(self):
|
||||||
@ -36,7 +56,7 @@ class BaseMetric(object):
|
|||||||
return not len(self.labels)
|
return not len(self.labels)
|
||||||
|
|
||||||
def _render_labelvalue(self, value):
|
def _render_labelvalue(self, value):
|
||||||
# TODO: some kind of value escape
|
# TODO: escape backslashes, quotes and newlines
|
||||||
return '"%s"' % (value)
|
return '"%s"' % (value)
|
||||||
|
|
||||||
def _render_key(self, values):
|
def _render_key(self, values):
|
||||||
@ -47,19 +67,60 @@ class BaseMetric(object):
|
|||||||
for k, v in zip(self.labels, values)])
|
for k, v in zip(self.labels, values)])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _render_for_labels(self, label_values, value):
|
||||||
|
"""Render this metric for a single set of labels
|
||||||
|
|
||||||
|
Args:
|
||||||
|
label_values (list[str]): values for each of the labels
|
||||||
|
value: value of the metric at with these labels
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
iterable[str]: rendered metric
|
||||||
|
"""
|
||||||
|
rendered_labels = self._render_key(label_values)
|
||||||
|
return (
|
||||||
|
"%s%s %.12g" % (name, rendered_labels, value)
|
||||||
|
for name in self._names
|
||||||
|
)
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
"""Render this metric
|
||||||
|
|
||||||
|
Each metric is rendered as:
|
||||||
|
|
||||||
|
name{label1="val1",label2="val2"} value
|
||||||
|
|
||||||
|
https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
iterable[str]: rendered metrics
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class CounterMetric(BaseMetric):
|
class CounterMetric(BaseMetric):
|
||||||
"""The simplest kind of metric; one that stores a monotonically-increasing
|
"""The simplest kind of metric; one that stores a monotonically-increasing
|
||||||
integer that counts events."""
|
value that counts events or running totals.
|
||||||
|
|
||||||
|
Example use cases for Counters:
|
||||||
|
- Number of requests processed
|
||||||
|
- Number of items that were inserted into a queue
|
||||||
|
- Total amount of data that a system has processed
|
||||||
|
Counters can only go up (and be reset when the process restarts).
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(CounterMetric, self).__init__(*args, **kwargs)
|
super(CounterMetric, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# dict[list[str]]: value for each set of label values. the keys are the
|
||||||
|
# label values, in the same order as the labels in self.labels.
|
||||||
|
#
|
||||||
|
# (if the metric is a scalar, the (single) key is the empty list).
|
||||||
self.counts = {}
|
self.counts = {}
|
||||||
|
|
||||||
# Scalar metrics are never empty
|
# Scalar metrics are never empty
|
||||||
if self.is_scalar():
|
if self.is_scalar():
|
||||||
self.counts[()] = 0
|
self.counts[()] = 0.
|
||||||
|
|
||||||
def inc_by(self, incr, *values):
|
def inc_by(self, incr, *values):
|
||||||
if len(values) != self.dimension():
|
if len(values) != self.dimension():
|
||||||
@ -77,11 +138,11 @@ class CounterMetric(BaseMetric):
|
|||||||
def inc(self, *values):
|
def inc(self, *values):
|
||||||
self.inc_by(1, *values)
|
self.inc_by(1, *values)
|
||||||
|
|
||||||
def render_item(self, k):
|
|
||||||
return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
|
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return map_concat(self.render_item, sorted(self.counts.keys()))
|
return flatten(
|
||||||
|
self._render_for_labels(k, self.counts[k])
|
||||||
|
for k in sorted(self.counts.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CallbackMetric(BaseMetric):
|
class CallbackMetric(BaseMetric):
|
||||||
@ -95,13 +156,19 @@ class CallbackMetric(BaseMetric):
|
|||||||
self.callback = callback
|
self.callback = callback
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
value = self.callback()
|
try:
|
||||||
|
value = self.callback()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to render %s", self.name)
|
||||||
|
return ["# FAILED to render " + self.name]
|
||||||
|
|
||||||
if self.is_scalar():
|
if self.is_scalar():
|
||||||
return ["%s %.12g" % (self.name, value)]
|
return list(self._render_for_labels([], value))
|
||||||
|
|
||||||
return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
|
return flatten(
|
||||||
for k in sorted(value.keys())]
|
self._render_for_labels(k, value[k])
|
||||||
|
for k in sorted(value.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DistributionMetric(object):
|
class DistributionMetric(object):
|
||||||
|
@ -13,21 +13,30 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
from synapse.push import PusherConfigException
|
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||||
|
|
||||||
import logging
|
|
||||||
import push_rule_evaluator
|
import push_rule_evaluator
|
||||||
import push_tools
|
import push_tools
|
||||||
|
import synapse
|
||||||
|
from synapse.push import PusherConfigException
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
|
http_push_processed_counter = metrics.register_counter(
|
||||||
|
"http_pushes_processed",
|
||||||
|
)
|
||||||
|
|
||||||
|
http_push_failed_counter = metrics.register_counter(
|
||||||
|
"http_pushes_failed",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HttpPusher(object):
|
class HttpPusher(object):
|
||||||
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
|
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
|
||||||
@ -152,9 +161,16 @@ class HttpPusher(object):
|
|||||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Processing %i unprocessed push actions for %s starting at "
|
||||||
|
"stream_ordering %s",
|
||||||
|
len(unprocessed), self.name, self.last_stream_ordering,
|
||||||
|
)
|
||||||
|
|
||||||
for push_action in unprocessed:
|
for push_action in unprocessed:
|
||||||
processed = yield self._process_one(push_action)
|
processed = yield self._process_one(push_action)
|
||||||
if processed:
|
if processed:
|
||||||
|
http_push_processed_counter.inc()
|
||||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||||
self.last_stream_ordering = push_action['stream_ordering']
|
self.last_stream_ordering = push_action['stream_ordering']
|
||||||
yield self.store.update_pusher_last_stream_ordering_and_success(
|
yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||||
@ -169,6 +185,7 @@ class HttpPusher(object):
|
|||||||
self.failing_since
|
self.failing_since
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
http_push_failed_counter.inc()
|
||||||
if not self.failing_since:
|
if not self.failing_since:
|
||||||
self.failing_since = self.clock.time_msec()
|
self.failing_since = self.clock.time_msec()
|
||||||
yield self.store.update_pusher_failing_since(
|
yield self.store.update_pusher_failing_since(
|
||||||
@ -316,7 +333,10 @@ class HttpPusher(object):
|
|||||||
try:
|
try:
|
||||||
resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
|
resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warn("Failed to push %s ", self.url)
|
logger.warn(
|
||||||
|
"Failed to push event %s to %s",
|
||||||
|
event.event_id, self.name, exc_info=True,
|
||||||
|
)
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
rejected = []
|
rejected = []
|
||||||
if 'rejected' in resp:
|
if 'rejected' in resp:
|
||||||
@ -325,7 +345,7 @@ class HttpPusher(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _send_badge(self, badge):
|
def _send_badge(self, badge):
|
||||||
logger.info("Sending updated badge count %d to %r", badge, self.user_id)
|
logger.info("Sending updated badge count %d to %s", badge, self.name)
|
||||||
d = {
|
d = {
|
||||||
'notification': {
|
'notification': {
|
||||||
'id': '',
|
'id': '',
|
||||||
@ -347,7 +367,10 @@ class HttpPusher(object):
|
|||||||
try:
|
try:
|
||||||
resp = yield self.http_client.post_json_get_json(self.url, d)
|
resp = yield self.http_client.post_json_get_json(self.url, d)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to push %s ", self.url)
|
logger.warn(
|
||||||
|
"Failed to send badge count to %s",
|
||||||
|
self.name, exc_info=True,
|
||||||
|
)
|
||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
rejected = []
|
rejected = []
|
||||||
if 'rejected' in resp:
|
if 'rejected' in resp:
|
||||||
|
@ -517,25 +517,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
self.send_error("Wrong remote")
|
self.send_error("Wrong remote")
|
||||||
|
|
||||||
def on_RDATA(self, cmd):
|
def on_RDATA(self, cmd):
|
||||||
|
stream_name = cmd.stream_name
|
||||||
|
inbound_rdata_count.inc(stream_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
|
row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"[%s] Failed to parse RDATA: %r %r",
|
"[%s] Failed to parse RDATA: %r %r",
|
||||||
self.id(), cmd.stream_name, cmd.row
|
self.id(), stream_name, cmd.row
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if cmd.token is None:
|
if cmd.token is None:
|
||||||
# I.e. this is part of a batch of updates for this stream. Batch
|
# I.e. this is part of a batch of updates for this stream. Batch
|
||||||
# until we get an update for the stream with a non None token
|
# until we get an update for the stream with a non None token
|
||||||
self.pending_batches.setdefault(cmd.stream_name, []).append(row)
|
self.pending_batches.setdefault(stream_name, []).append(row)
|
||||||
else:
|
else:
|
||||||
# Check if this is the last of a batch of updates
|
# Check if this is the last of a batch of updates
|
||||||
rows = self.pending_batches.pop(cmd.stream_name, [])
|
rows = self.pending_batches.pop(stream_name, [])
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
|
|
||||||
self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
|
self.handler.on_rdata(stream_name, cmd.token, rows)
|
||||||
|
|
||||||
def on_POSITION(self, cmd):
|
def on_POSITION(self, cmd):
|
||||||
self.handler.on_position(cmd.stream_name, cmd.token)
|
self.handler.on_position(cmd.stream_name, cmd.token)
|
||||||
@ -644,3 +647,9 @@ metrics.register_callback(
|
|||||||
},
|
},
|
||||||
labels=["command", "name", "conn_id"],
|
labels=["command", "name", "conn_id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# number of updates received for each RDATA stream
|
||||||
|
inbound_rdata_count = metrics.register_counter(
|
||||||
|
"inbound_rdata_count",
|
||||||
|
labels=["stream_name"],
|
||||||
|
)
|
||||||
|
@ -191,19 +191,25 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
# convert threepid identifiers to user IDs
|
# convert threepid identifiers to user IDs
|
||||||
if identifier["type"] == "m.id.thirdparty":
|
if identifier["type"] == "m.id.thirdparty":
|
||||||
if 'medium' not in identifier or 'address' not in identifier:
|
address = identifier.get('address')
|
||||||
|
medium = identifier.get('medium')
|
||||||
|
|
||||||
|
if medium is None or address is None:
|
||||||
raise SynapseError(400, "Invalid thirdparty identifier")
|
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||||
|
|
||||||
address = identifier['address']
|
if medium == 'email':
|
||||||
if identifier['medium'] == 'email':
|
|
||||||
# For emails, transform the address to lowercase.
|
# For emails, transform the address to lowercase.
|
||||||
# We store all email addreses as lowercase in the DB.
|
# We store all email addreses as lowercase in the DB.
|
||||||
# (See add_threepid in synapse/handlers/auth.py)
|
# (See add_threepid in synapse/handlers/auth.py)
|
||||||
address = address.lower()
|
address = address.lower()
|
||||||
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
identifier['medium'], address
|
medium, address,
|
||||||
)
|
)
|
||||||
if not user_id:
|
if not user_id:
|
||||||
|
logger.warn(
|
||||||
|
"unknown 3pid identifier medium %s, address %r",
|
||||||
|
medium, address,
|
||||||
|
)
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
identifier = {
|
identifier = {
|
||||||
|
@ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
|
|
||||||
|
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||||
|
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||||
|
|
||||||
|
flows = []
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
return (
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
200,
|
if not require_msisdn:
|
||||||
{"flows": [
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.RECAPTCHA,
|
"type": LoginType.RECAPTCHA,
|
||||||
"stages": [
|
"stages": [
|
||||||
@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||||||
LoginType.PASSWORD
|
LoginType.PASSWORD
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
])
|
||||||
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
|
if not require_email and not require_msisdn:
|
||||||
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.RECAPTCHA,
|
"type": LoginType.RECAPTCHA,
|
||||||
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
|
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
|
||||||
}
|
}
|
||||||
]}
|
])
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return (
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
200,
|
if require_email or not require_msisdn:
|
||||||
{"flows": [
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.EMAIL_IDENTITY,
|
"type": LoginType.EMAIL_IDENTITY,
|
||||||
"stages": [
|
"stages": [
|
||||||
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
|
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
|
||||||
]
|
]
|
||||||
},
|
}
|
||||||
|
])
|
||||||
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
|
if not require_email and not require_msisdn:
|
||||||
|
flows.extend([
|
||||||
{
|
{
|
||||||
"type": LoginType.PASSWORD
|
"type": LoginType.PASSWORD
|
||||||
}
|
}
|
||||||
]}
|
])
|
||||||
)
|
return (200, {"flows": flows})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
@ -195,15 +195,20 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
|||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
event_dict = {
|
||||||
|
"type": event_type,
|
||||||
|
"content": content,
|
||||||
|
"room_id": room_id,
|
||||||
|
"sender": requester.user.to_string(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if 'ts' in request.args and requester.app_service:
|
||||||
|
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
|
||||||
|
|
||||||
msg_handler = self.handlers.message_handler
|
msg_handler = self.handlers.message_handler
|
||||||
event = yield msg_handler.create_and_send_nonmember_event(
|
event = yield msg_handler.create_and_send_nonmember_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
event_dict,
|
||||||
"type": event_type,
|
|
||||||
"content": content,
|
|
||||||
"room_id": room_id,
|
|
||||||
"sender": requester.user.to_string(),
|
|
||||||
},
|
|
||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -487,13 +492,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
|
|||||||
defer.returnValue((200, content))
|
defer.returnValue((200, content))
|
||||||
|
|
||||||
|
|
||||||
class RoomEventContext(ClientV1RestServlet):
|
class RoomEventServlet(ClientV1RestServlet):
|
||||||
|
PATTERNS = client_path_patterns(
|
||||||
|
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(RoomEventServlet, self).__init__(hs)
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.event_handler = hs.get_event_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, room_id, event_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
event = yield self.event_handler.get_event(requester.user, event_id)
|
||||||
|
|
||||||
|
time_now = self.clock.time_msec()
|
||||||
|
if event:
|
||||||
|
defer.returnValue((200, serialize_event(event, time_now)))
|
||||||
|
else:
|
||||||
|
defer.returnValue((404, "Event not found."))
|
||||||
|
|
||||||
|
|
||||||
|
class RoomEventContextServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns(
|
PATTERNS = client_path_patterns(
|
||||||
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
|
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RoomEventContext, self).__init__(hs)
|
super(RoomEventContextServlet, self).__init__(hs)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@ -803,4 +830,5 @@ def register_servlets(hs, http_server):
|
|||||||
RoomTypingRestServlet(hs).register(http_server)
|
RoomTypingRestServlet(hs).register(http_server)
|
||||||
SearchRestServlet(hs).register(http_server)
|
SearchRestServlet(hs).register(http_server)
|
||||||
JoinedRoomsRestServlet(hs).register(http_server)
|
JoinedRoomsRestServlet(hs).register(http_server)
|
||||||
RoomEventContext(hs).register(http_server)
|
RoomEventServlet(hs).register(http_server)
|
||||||
|
RoomEventContextServlet(hs).register(http_server)
|
||||||
|
@ -26,6 +26,7 @@ from synapse.http.servlet import (
|
|||||||
)
|
)
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
from ._base import client_v2_patterns, interactive_auth_handler
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
|||||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||||
])
|
])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
)
|
)
|
||||||
@ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
|||||||
|
|
||||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
'msisdn', msisdn
|
'msisdn', msisdn
|
||||||
)
|
)
|
||||||
@ -217,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
|||||||
if absent:
|
if absent:
|
||||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
)
|
)
|
||||||
@ -255,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
|||||||
|
|
||||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||||
'msisdn', msisdn
|
'msisdn', msisdn
|
||||||
)
|
)
|
||||||
|
@ -26,6 +26,7 @@ from synapse.http.servlet import (
|
|||||||
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
||||||
)
|
)
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
|
|
||||||
from ._base import client_v2_patterns, interactive_auth_handler
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
@ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
|||||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||||
])
|
])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
'email', body['email']
|
'email', body['email']
|
||||||
)
|
)
|
||||||
@ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
|||||||
|
|
||||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
'msisdn', msisdn
|
'msisdn', msisdn
|
||||||
)
|
)
|
||||||
@ -305,31 +316,67 @@ class RegisterRestServlet(RestServlet):
|
|||||||
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
||||||
show_msisdn = True
|
show_msisdn = True
|
||||||
|
|
||||||
|
# FIXME: need a better error than "no auth flow found" for scenarios
|
||||||
|
# where we required 3PID for registration but the user didn't give one
|
||||||
|
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||||
|
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||||
|
|
||||||
|
flows = []
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
flows = [
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
[LoginType.RECAPTCHA],
|
if not require_email and not require_msisdn:
|
||||||
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
flows.extend([[LoginType.RECAPTCHA]])
|
||||||
]
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
|
if not require_msisdn:
|
||||||
|
flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
|
||||||
|
|
||||||
if show_msisdn:
|
if show_msisdn:
|
||||||
|
# only support the MSISDN-only flow if we don't require email 3PIDs
|
||||||
|
if not require_email:
|
||||||
|
flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
|
||||||
|
# always let users provide both MSISDN & email
|
||||||
flows.extend([
|
flows.extend([
|
||||||
[LoginType.MSISDN, LoginType.RECAPTCHA],
|
|
||||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
flows = [
|
# only support 3PIDless registration if no 3PIDs are required
|
||||||
[LoginType.DUMMY],
|
if not require_email and not require_msisdn:
|
||||||
[LoginType.EMAIL_IDENTITY],
|
flows.extend([[LoginType.DUMMY]])
|
||||||
]
|
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||||
|
if not require_msisdn:
|
||||||
|
flows.extend([[LoginType.EMAIL_IDENTITY]])
|
||||||
|
|
||||||
if show_msisdn:
|
if show_msisdn:
|
||||||
|
# only support the MSISDN-only flow if we don't require email 3PIDs
|
||||||
|
if not require_email or require_msisdn:
|
||||||
|
flows.extend([[LoginType.MSISDN]])
|
||||||
|
# always let users provide both MSISDN & email
|
||||||
flows.extend([
|
flows.extend([
|
||||||
[LoginType.MSISDN],
|
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
|
||||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
|
||||||
])
|
])
|
||||||
|
|
||||||
auth_result, params, session_id = yield self.auth_handler.check_auth(
|
auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||||
flows, body, self.hs.get_ip_from_request(request)
|
flows, body, self.hs.get_ip_from_request(request)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check that we're not trying to register a denied 3pid.
|
||||||
|
#
|
||||||
|
# the user-facing checks will probably already have happened in
|
||||||
|
# /register/email/requestToken when we requested a 3pid, but that's not
|
||||||
|
# guaranteed.
|
||||||
|
|
||||||
|
if auth_result:
|
||||||
|
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
|
||||||
|
if login_type in auth_result:
|
||||||
|
medium = auth_result[login_type]['medium']
|
||||||
|
address = auth_result[login_type]['address']
|
||||||
|
|
||||||
|
if not check_3pid_allowed(self.hs, medium, address):
|
||||||
|
raise SynapseError(
|
||||||
|
403, "Third party identifier is not allowed",
|
||||||
|
Codes.THREEPID_DENIED,
|
||||||
|
)
|
||||||
|
|
||||||
if registered_user_id is not None:
|
if registered_user_id is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Already registered user ID %r for this session",
|
"Already registered user ID %r for this session",
|
||||||
|
@ -93,6 +93,7 @@ class RemoteKey(Resource):
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
self.async_render_GET(request)
|
self.async_render_GET(request)
|
||||||
@ -137,6 +138,13 @@ class RemoteKey(Resource):
|
|||||||
logger.info("Handling query for keys %r", query)
|
logger.info("Handling query for keys %r", query)
|
||||||
store_queries = []
|
store_queries = []
|
||||||
for server_name, key_ids in query.items():
|
for server_name, key_ids in query.items():
|
||||||
|
if (
|
||||||
|
self.federation_domain_whitelist is not None and
|
||||||
|
server_name not in self.federation_domain_whitelist
|
||||||
|
):
|
||||||
|
logger.debug("Federation denied with %s", server_name)
|
||||||
|
continue
|
||||||
|
|
||||||
if not key_ids:
|
if not key_ids:
|
||||||
key_ids = (None,)
|
key_ids = (None,)
|
||||||
for key_id in key_ids:
|
for key_id in key_ids:
|
||||||
|
@ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path,
|
|||||||
logger.debug("Responding with %r", file_path)
|
logger.debug("Responding with %r", file_path)
|
||||||
|
|
||||||
if os.path.isfile(file_path):
|
if os.path.isfile(file_path):
|
||||||
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
|
||||||
if upload_name:
|
|
||||||
if is_ascii(upload_name):
|
|
||||||
request.setHeader(
|
|
||||||
b"Content-Disposition",
|
|
||||||
b"inline; filename=%s" % (
|
|
||||||
urllib.quote(upload_name.encode("utf-8")),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
request.setHeader(
|
|
||||||
b"Content-Disposition",
|
|
||||||
b"inline; filename*=utf-8''%s" % (
|
|
||||||
urllib.quote(upload_name.encode("utf-8")),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# cache for at least a day.
|
|
||||||
# XXX: we might want to turn this off for data we don't want to
|
|
||||||
# recommend caching as it's sensitive or private - or at least
|
|
||||||
# select private. don't bother setting Expires as all our
|
|
||||||
# clients are smart enough to be happy with Cache-Control
|
|
||||||
request.setHeader(
|
|
||||||
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
|
|
||||||
)
|
|
||||||
if file_size is None:
|
if file_size is None:
|
||||||
stat = os.stat(file_path)
|
stat = os.stat(file_path)
|
||||||
file_size = stat.st_size
|
file_size = stat.st_size
|
||||||
|
|
||||||
request.setHeader(
|
add_file_headers(request, media_type, file_size, upload_name)
|
||||||
b"Content-Length", b"%d" % (file_size,)
|
|
||||||
)
|
|
||||||
|
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
yield logcontext.make_deferred_yieldable(
|
yield logcontext.make_deferred_yieldable(
|
||||||
@ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path,
|
|||||||
finish_request(request)
|
finish_request(request)
|
||||||
else:
|
else:
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
|
|
||||||
|
|
||||||
|
def add_file_headers(request, media_type, file_size, upload_name):
|
||||||
|
"""Adds the correct response headers in preparation for responding with the
|
||||||
|
media.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (twisted.web.http.Request)
|
||||||
|
media_type (str): The media/content type.
|
||||||
|
file_size (int): Size in bytes of the media, if known.
|
||||||
|
upload_name (str): The name of the requested file, if any.
|
||||||
|
"""
|
||||||
|
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
||||||
|
if upload_name:
|
||||||
|
if is_ascii(upload_name):
|
||||||
|
request.setHeader(
|
||||||
|
b"Content-Disposition",
|
||||||
|
b"inline; filename=%s" % (
|
||||||
|
urllib.quote(upload_name.encode("utf-8")),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
request.setHeader(
|
||||||
|
b"Content-Disposition",
|
||||||
|
b"inline; filename*=utf-8''%s" % (
|
||||||
|
urllib.quote(upload_name.encode("utf-8")),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# cache for at least a day.
|
||||||
|
# XXX: we might want to turn this off for data we don't want to
|
||||||
|
# recommend caching as it's sensitive or private - or at least
|
||||||
|
# select private. don't bother setting Expires as all our
|
||||||
|
# clients are smart enough to be happy with Cache-Control
|
||||||
|
request.setHeader(
|
||||||
|
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
|
||||||
|
)
|
||||||
|
|
||||||
|
request.setHeader(
|
||||||
|
b"Content-Length", b"%d" % (file_size,)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
|
||||||
|
"""Responds to the request with given responder. If responder is None then
|
||||||
|
returns 404.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (twisted.web.http.Request)
|
||||||
|
responder (Responder|None)
|
||||||
|
media_type (str): The media/content type.
|
||||||
|
file_size (int|None): Size in bytes of the media. If not known it should be None
|
||||||
|
upload_name (str|None): The name of the requested file, if any.
|
||||||
|
"""
|
||||||
|
if not responder:
|
||||||
|
respond_404(request)
|
||||||
|
return
|
||||||
|
|
||||||
|
add_file_headers(request, media_type, file_size, upload_name)
|
||||||
|
with responder:
|
||||||
|
yield responder.write_to_consumer(request)
|
||||||
|
finish_request(request)
|
||||||
|
|
||||||
|
|
||||||
|
class Responder(object):
|
||||||
|
"""Represents a response that can be streamed to the requester.
|
||||||
|
|
||||||
|
Responder is a context manager which *must* be used, so that any resources
|
||||||
|
held can be cleaned up.
|
||||||
|
"""
|
||||||
|
def write_to_consumer(self, consumer):
|
||||||
|
"""Stream response into consumer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
consumer (IConsumer)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: Resolves once the response has finished being written
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FileInfo(object):
|
||||||
|
"""Details about a requested/uploaded file.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
server_name (str): The server name where the media originated from,
|
||||||
|
or None if local.
|
||||||
|
file_id (str): The local ID of the file. For local files this is the
|
||||||
|
same as the media_id
|
||||||
|
url_cache (bool): If the file is for the url preview cache
|
||||||
|
thumbnail (bool): Whether the file is a thumbnail or not.
|
||||||
|
thumbnail_width (int)
|
||||||
|
thumbnail_height (int)
|
||||||
|
thumbnail_method (str)
|
||||||
|
thumbnail_type (str): Content type of thumbnail, e.g. image/png
|
||||||
|
"""
|
||||||
|
def __init__(self, server_name, file_id, url_cache=False,
|
||||||
|
thumbnail=False, thumbnail_width=None, thumbnail_height=None,
|
||||||
|
thumbnail_method=None, thumbnail_type=None):
|
||||||
|
self.server_name = server_name
|
||||||
|
self.file_id = file_id
|
||||||
|
self.url_cache = url_cache
|
||||||
|
self.thumbnail = thumbnail
|
||||||
|
self.thumbnail_width = thumbnail_width
|
||||||
|
self.thumbnail_height = thumbnail_height
|
||||||
|
self.thumbnail_method = thumbnail_method
|
||||||
|
self.thumbnail_type = thumbnail_type
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import synapse.http.servlet
|
import synapse.http.servlet
|
||||||
|
|
||||||
from ._base import parse_media_id, respond_with_file, respond_404
|
from ._base import parse_media_id, respond_404
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from synapse.http.server import request_handler, set_cors_headers
|
from synapse.http.server import request_handler, set_cors_headers
|
||||||
|
|
||||||
@ -32,12 +32,12 @@ class DownloadResource(Resource):
|
|||||||
def __init__(self, hs, media_repo):
|
def __init__(self, hs, media_repo):
|
||||||
Resource.__init__(self)
|
Resource.__init__(self)
|
||||||
|
|
||||||
self.filepaths = media_repo.filepaths
|
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.store = hs.get_datastore()
|
|
||||||
self.version_string = hs.version_string
|
# Both of these are expected by @request_handler()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.version_string = hs.version_string
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
self._async_render_GET(request)
|
self._async_render_GET(request)
|
||||||
@ -57,59 +57,16 @@ class DownloadResource(Resource):
|
|||||||
)
|
)
|
||||||
server_name, media_id, name = parse_media_id(request)
|
server_name, media_id, name = parse_media_id(request)
|
||||||
if server_name == self.server_name:
|
if server_name == self.server_name:
|
||||||
yield self._respond_local_file(request, media_id, name)
|
yield self.media_repo.get_local_media(request, media_id, name)
|
||||||
else:
|
else:
|
||||||
yield self._respond_remote_file(
|
allow_remote = synapse.http.servlet.parse_boolean(
|
||||||
request, server_name, media_id, name
|
request, "allow_remote", default=True)
|
||||||
)
|
if not allow_remote:
|
||||||
|
logger.info(
|
||||||
|
"Rejecting request for remote media %s/%s due to allow_remote",
|
||||||
|
server_name, media_id,
|
||||||
|
)
|
||||||
|
respond_404(request)
|
||||||
|
return
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
yield self.media_repo.get_remote_media(request, server_name, media_id, name)
|
||||||
def _respond_local_file(self, request, media_id, name):
|
|
||||||
media_info = yield self.store.get_local_media(media_id)
|
|
||||||
if not media_info or media_info["quarantined_by"]:
|
|
||||||
respond_404(request)
|
|
||||||
return
|
|
||||||
|
|
||||||
media_type = media_info["media_type"]
|
|
||||||
media_length = media_info["media_length"]
|
|
||||||
upload_name = name if name else media_info["upload_name"]
|
|
||||||
if media_info["url_cache"]:
|
|
||||||
# TODO: Check the file still exists, if it doesn't we can redownload
|
|
||||||
# it from the url `media_info["url_cache"]`
|
|
||||||
file_path = self.filepaths.url_cache_filepath(media_id)
|
|
||||||
else:
|
|
||||||
file_path = self.filepaths.local_media_filepath(media_id)
|
|
||||||
|
|
||||||
yield respond_with_file(
|
|
||||||
request, media_type, file_path, media_length,
|
|
||||||
upload_name=upload_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _respond_remote_file(self, request, server_name, media_id, name):
|
|
||||||
# don't forward requests for remote media if allow_remote is false
|
|
||||||
allow_remote = synapse.http.servlet.parse_boolean(
|
|
||||||
request, "allow_remote", default=True)
|
|
||||||
if not allow_remote:
|
|
||||||
logger.info(
|
|
||||||
"Rejecting request for remote media %s/%s due to allow_remote",
|
|
||||||
server_name, media_id,
|
|
||||||
)
|
|
||||||
respond_404(request)
|
|
||||||
return
|
|
||||||
|
|
||||||
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
|
||||||
|
|
||||||
media_type = media_info["media_type"]
|
|
||||||
media_length = media_info["media_length"]
|
|
||||||
filesystem_id = media_info["filesystem_id"]
|
|
||||||
upload_name = name if name else media_info["upload_name"]
|
|
||||||
|
|
||||||
file_path = self.filepaths.remote_media_filepath(
|
|
||||||
server_name, filesystem_id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield respond_with_file(
|
|
||||||
request, media_type, file_path, media_length,
|
|
||||||
upload_name=upload_name,
|
|
||||||
)
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -18,6 +19,7 @@ import twisted.internet.error
|
|||||||
import twisted.web.http
|
import twisted.web.http
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
from ._base import respond_404, FileInfo, respond_with_responder
|
||||||
from .upload_resource import UploadResource
|
from .upload_resource import UploadResource
|
||||||
from .download_resource import DownloadResource
|
from .download_resource import DownloadResource
|
||||||
from .thumbnail_resource import ThumbnailResource
|
from .thumbnail_resource import ThumbnailResource
|
||||||
@ -25,15 +27,18 @@ from .identicon_resource import IdenticonResource
|
|||||||
from .preview_url_resource import PreviewUrlResource
|
from .preview_url_resource import PreviewUrlResource
|
||||||
from .filepath import MediaFilePaths
|
from .filepath import MediaFilePaths
|
||||||
from .thumbnailer import Thumbnailer
|
from .thumbnailer import Thumbnailer
|
||||||
|
from .storage_provider import StorageProviderWrapper
|
||||||
|
from .media_storage import MediaStorage
|
||||||
|
|
||||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
from synapse.api.errors import SynapseError, HttpResponseException, \
|
from synapse.api.errors import (
|
||||||
NotFoundError
|
SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
|
||||||
|
)
|
||||||
|
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.stringutils import is_ascii
|
from synapse.util.stringutils import is_ascii
|
||||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -47,7 +52,7 @@ import urlparse
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
|
UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
|
||||||
|
|
||||||
|
|
||||||
class MediaRepository(object):
|
class MediaRepository(object):
|
||||||
@ -63,96 +68,62 @@ class MediaRepository(object):
|
|||||||
self.primary_base_path = hs.config.media_store_path
|
self.primary_base_path = hs.config.media_store_path
|
||||||
self.filepaths = MediaFilePaths(self.primary_base_path)
|
self.filepaths = MediaFilePaths(self.primary_base_path)
|
||||||
|
|
||||||
self.backup_base_path = hs.config.backup_media_store_path
|
|
||||||
|
|
||||||
self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
|
|
||||||
|
|
||||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||||
|
|
||||||
self.remote_media_linearizer = Linearizer(name="media_remote")
|
self.remote_media_linearizer = Linearizer(name="media_remote")
|
||||||
|
|
||||||
self.recently_accessed_remotes = set()
|
self.recently_accessed_remotes = set()
|
||||||
|
self.recently_accessed_locals = set()
|
||||||
|
|
||||||
|
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||||
|
|
||||||
|
# List of StorageProviders where we should search for media and
|
||||||
|
# potentially upload to.
|
||||||
|
storage_providers = []
|
||||||
|
|
||||||
|
for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
|
||||||
|
backend = clz(hs, provider_config)
|
||||||
|
provider = StorageProviderWrapper(
|
||||||
|
backend,
|
||||||
|
store_local=wrapper_config.store_local,
|
||||||
|
store_remote=wrapper_config.store_remote,
|
||||||
|
store_synchronous=wrapper_config.store_synchronous,
|
||||||
|
)
|
||||||
|
storage_providers.append(provider)
|
||||||
|
|
||||||
|
self.media_storage = MediaStorage(
|
||||||
|
self.primary_base_path, self.filepaths, storage_providers,
|
||||||
|
)
|
||||||
|
|
||||||
self.clock.looping_call(
|
self.clock.looping_call(
|
||||||
self._update_recently_accessed_remotes,
|
self._update_recently_accessed,
|
||||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS
|
UPDATE_RECENTLY_ACCESSED_TS,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _update_recently_accessed_remotes(self):
|
def _update_recently_accessed(self):
|
||||||
media = self.recently_accessed_remotes
|
remote_media = self.recently_accessed_remotes
|
||||||
self.recently_accessed_remotes = set()
|
self.recently_accessed_remotes = set()
|
||||||
|
|
||||||
|
local_media = self.recently_accessed_locals
|
||||||
|
self.recently_accessed_locals = set()
|
||||||
|
|
||||||
yield self.store.update_cached_last_access_time(
|
yield self.store.update_cached_last_access_time(
|
||||||
media, self.clock.time_msec()
|
local_media, remote_media, self.clock.time_msec()
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
def mark_recently_accessed(self, server_name, media_id):
|
||||||
def _makedirs(filepath):
|
"""Mark the given media as recently accessed.
|
||||||
dirname = os.path.dirname(filepath)
|
|
||||||
if not os.path.exists(dirname):
|
|
||||||
os.makedirs(dirname)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _write_file_synchronously(source, fname):
|
|
||||||
"""Write `source` to the path `fname` synchronously. Should be called
|
|
||||||
from a thread.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source: A file like object to be written
|
server_name (str|None): Origin server of media, or None if local
|
||||||
fname (str): Path to write to
|
media_id (str): The media ID of the content
|
||||||
"""
|
"""
|
||||||
MediaRepository._makedirs(fname)
|
if server_name:
|
||||||
source.seek(0) # Ensure we read from the start of the file
|
self.recently_accessed_remotes.add((server_name, media_id))
|
||||||
with open(fname, "wb") as f:
|
else:
|
||||||
shutil.copyfileobj(source, f)
|
self.recently_accessed_locals.add(media_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def write_to_file_and_backup(self, source, path):
|
|
||||||
"""Write `source` to the on disk media store, and also the backup store
|
|
||||||
if configured.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source: A file like object that should be written
|
|
||||||
path (str): Relative path to write file to
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[str]: the file path written to in the primary media store
|
|
||||||
"""
|
|
||||||
fname = os.path.join(self.primary_base_path, path)
|
|
||||||
|
|
||||||
# Write to the main repository
|
|
||||||
yield make_deferred_yieldable(threads.deferToThread(
|
|
||||||
self._write_file_synchronously, source, fname,
|
|
||||||
))
|
|
||||||
|
|
||||||
# Write to backup repository
|
|
||||||
yield self.copy_to_backup(path)
|
|
||||||
|
|
||||||
defer.returnValue(fname)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def copy_to_backup(self, path):
|
|
||||||
"""Copy a file from the primary to backup media store, if configured.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path(str): Relative path to write file to
|
|
||||||
"""
|
|
||||||
if self.backup_base_path:
|
|
||||||
primary_fname = os.path.join(self.primary_base_path, path)
|
|
||||||
backup_fname = os.path.join(self.backup_base_path, path)
|
|
||||||
|
|
||||||
# We can either wait for successful writing to the backup repository
|
|
||||||
# or write in the background and immediately return
|
|
||||||
if self.synchronous_backup_media_store:
|
|
||||||
yield make_deferred_yieldable(threads.deferToThread(
|
|
||||||
shutil.copyfile, primary_fname, backup_fname,
|
|
||||||
))
|
|
||||||
else:
|
|
||||||
preserve_fn(threads.deferToThread)(
|
|
||||||
shutil.copyfile, primary_fname, backup_fname,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_content(self, media_type, upload_name, content, content_length,
|
def create_content(self, media_type, upload_name, content, content_length,
|
||||||
@ -171,10 +142,13 @@ class MediaRepository(object):
|
|||||||
"""
|
"""
|
||||||
media_id = random_string(24)
|
media_id = random_string(24)
|
||||||
|
|
||||||
fname = yield self.write_to_file_and_backup(
|
file_info = FileInfo(
|
||||||
content, self.filepaths.local_media_filepath_rel(media_id)
|
server_name=None,
|
||||||
|
file_id=media_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
fname = yield self.media_storage.store_file(content, file_info)
|
||||||
|
|
||||||
logger.info("Stored local media in file %r", fname)
|
logger.info("Stored local media in file %r", fname)
|
||||||
|
|
||||||
yield self.store.store_local_media(
|
yield self.store.store_local_media(
|
||||||
@ -185,134 +159,275 @@ class MediaRepository(object):
|
|||||||
media_length=content_length,
|
media_length=content_length,
|
||||||
user_id=auth_user,
|
user_id=auth_user,
|
||||||
)
|
)
|
||||||
media_info = {
|
|
||||||
"media_type": media_type,
|
|
||||||
"media_length": content_length,
|
|
||||||
}
|
|
||||||
|
|
||||||
yield self._generate_thumbnails(None, media_id, media_info)
|
yield self._generate_thumbnails(
|
||||||
|
None, media_id, media_id, media_type,
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_remote_media(self, server_name, media_id):
|
def get_local_media(self, request, media_id, name):
|
||||||
|
"""Responds to reqests for local media, if exists, or returns 404.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request(twisted.web.http.Request)
|
||||||
|
media_id (str): The media ID of the content. (This is the same as
|
||||||
|
the file_id for local content.)
|
||||||
|
name (str|None): Optional name that, if specified, will be used as
|
||||||
|
the filename in the Content-Disposition header of the response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: Resolves once a response has successfully been written
|
||||||
|
to request
|
||||||
|
"""
|
||||||
|
media_info = yield self.store.get_local_media(media_id)
|
||||||
|
if not media_info or media_info["quarantined_by"]:
|
||||||
|
respond_404(request)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.mark_recently_accessed(None, media_id)
|
||||||
|
|
||||||
|
media_type = media_info["media_type"]
|
||||||
|
media_length = media_info["media_length"]
|
||||||
|
upload_name = name if name else media_info["upload_name"]
|
||||||
|
url_cache = media_info["url_cache"]
|
||||||
|
|
||||||
|
file_info = FileInfo(
|
||||||
|
None, media_id,
|
||||||
|
url_cache=url_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
responder = yield self.media_storage.fetch_media(file_info)
|
||||||
|
yield respond_with_responder(
|
||||||
|
request, responder, media_type, media_length, upload_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_remote_media(self, request, server_name, media_id, name):
|
||||||
|
"""Respond to requests for remote media.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request(twisted.web.http.Request)
|
||||||
|
server_name (str): Remote server_name where the media originated.
|
||||||
|
media_id (str): The media ID of the content (as defined by the
|
||||||
|
remote server).
|
||||||
|
name (str|None): Optional name that, if specified, will be used as
|
||||||
|
the filename in the Content-Disposition header of the response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: Resolves once a response has successfully been written
|
||||||
|
to request
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
self.federation_domain_whitelist is not None and
|
||||||
|
server_name not in self.federation_domain_whitelist
|
||||||
|
):
|
||||||
|
raise FederationDeniedError(server_name)
|
||||||
|
|
||||||
|
self.mark_recently_accessed(server_name, media_id)
|
||||||
|
|
||||||
|
# We linearize here to ensure that we don't try and download remote
|
||||||
|
# media multiple times concurrently
|
||||||
key = (server_name, media_id)
|
key = (server_name, media_id)
|
||||||
with (yield self.remote_media_linearizer.queue(key)):
|
with (yield self.remote_media_linearizer.queue(key)):
|
||||||
media_info = yield self._get_remote_media_impl(server_name, media_id)
|
responder, media_info = yield self._get_remote_media_impl(
|
||||||
|
server_name, media_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We deliberately stream the file outside the lock
|
||||||
|
if responder:
|
||||||
|
media_type = media_info["media_type"]
|
||||||
|
media_length = media_info["media_length"]
|
||||||
|
upload_name = name if name else media_info["upload_name"]
|
||||||
|
yield respond_with_responder(
|
||||||
|
request, responder, media_type, media_length, upload_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
respond_404(request)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_remote_media_info(self, server_name, media_id):
|
||||||
|
"""Gets the media info associated with the remote file, downloading
|
||||||
|
if necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name (str): Remote server_name where the media originated.
|
||||||
|
media_id (str): The media ID of the content (as defined by the
|
||||||
|
remote server).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[dict]: The media_info of the file
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
self.federation_domain_whitelist is not None and
|
||||||
|
server_name not in self.federation_domain_whitelist
|
||||||
|
):
|
||||||
|
raise FederationDeniedError(server_name)
|
||||||
|
|
||||||
|
# We linearize here to ensure that we don't try and download remote
|
||||||
|
# media multiple times concurrently
|
||||||
|
key = (server_name, media_id)
|
||||||
|
with (yield self.remote_media_linearizer.queue(key)):
|
||||||
|
responder, media_info = yield self._get_remote_media_impl(
|
||||||
|
server_name, media_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure we actually use the responder so that it releases resources
|
||||||
|
if responder:
|
||||||
|
with responder:
|
||||||
|
pass
|
||||||
|
|
||||||
defer.returnValue(media_info)
|
defer.returnValue(media_info)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_remote_media_impl(self, server_name, media_id):
|
def _get_remote_media_impl(self, server_name, media_id):
|
||||||
|
"""Looks for media in local cache, if not there then attempt to
|
||||||
|
download from remote server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name (str): Remote server_name where the media originated.
|
||||||
|
media_id (str): The media ID of the content (as defined by the
|
||||||
|
remote server).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[(Responder, media_info)]
|
||||||
|
"""
|
||||||
media_info = yield self.store.get_cached_remote_media(
|
media_info = yield self.store.get_cached_remote_media(
|
||||||
server_name, media_id
|
server_name, media_id
|
||||||
)
|
)
|
||||||
if not media_info:
|
|
||||||
media_info = yield self._download_remote_file(
|
# file_id is the ID we use to track the file locally. If we've already
|
||||||
server_name, media_id
|
# seen the file then reuse the existing ID, otherwise genereate a new
|
||||||
)
|
# one.
|
||||||
elif media_info["quarantined_by"]:
|
if media_info:
|
||||||
raise NotFoundError()
|
file_id = media_info["filesystem_id"]
|
||||||
else:
|
else:
|
||||||
self.recently_accessed_remotes.add((server_name, media_id))
|
file_id = random_string(24)
|
||||||
yield self.store.update_cached_last_access_time(
|
|
||||||
[(server_name, media_id)], self.clock.time_msec()
|
file_info = FileInfo(server_name, file_id)
|
||||||
)
|
|
||||||
defer.returnValue(media_info)
|
# If we have an entry in the DB, try and look for it
|
||||||
|
if media_info:
|
||||||
|
if media_info["quarantined_by"]:
|
||||||
|
logger.info("Media is quarantined")
|
||||||
|
raise NotFoundError()
|
||||||
|
|
||||||
|
responder = yield self.media_storage.fetch_media(file_info)
|
||||||
|
if responder:
|
||||||
|
defer.returnValue((responder, media_info))
|
||||||
|
|
||||||
|
# Failed to find the file anywhere, lets download it.
|
||||||
|
|
||||||
|
media_info = yield self._download_remote_file(
|
||||||
|
server_name, media_id, file_id
|
||||||
|
)
|
||||||
|
|
||||||
|
responder = yield self.media_storage.fetch_media(file_info)
|
||||||
|
defer.returnValue((responder, media_info))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _download_remote_file(self, server_name, media_id):
|
def _download_remote_file(self, server_name, media_id, file_id):
|
||||||
file_id = random_string(24)
|
"""Attempt to download the remote file from the given server name,
|
||||||
|
using the given file_id as the local id.
|
||||||
|
|
||||||
fpath = self.filepaths.remote_media_filepath_rel(
|
Args:
|
||||||
server_name, file_id
|
server_name (str): Originating server
|
||||||
|
media_id (str): The media ID of the content (as defined by the
|
||||||
|
remote server). This is different than the file_id, which is
|
||||||
|
locally generated.
|
||||||
|
file_id (str): Local file ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[MediaInfo]
|
||||||
|
"""
|
||||||
|
|
||||||
|
file_info = FileInfo(
|
||||||
|
server_name=server_name,
|
||||||
|
file_id=file_id,
|
||||||
)
|
)
|
||||||
fname = os.path.join(self.primary_base_path, fpath)
|
|
||||||
self._makedirs(fname)
|
|
||||||
|
|
||||||
try:
|
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||||
with open(fname, "wb") as f:
|
request_path = "/".join((
|
||||||
request_path = "/".join((
|
"/_matrix/media/v1/download", server_name, media_id,
|
||||||
"/_matrix/media/v1/download", server_name, media_id,
|
))
|
||||||
))
|
try:
|
||||||
|
length, headers = yield self.client.get_file(
|
||||||
|
server_name, request_path, output_stream=f,
|
||||||
|
max_size=self.max_upload_size, args={
|
||||||
|
# tell the remote server to 404 if it doesn't
|
||||||
|
# recognise the server_name, to make sure we don't
|
||||||
|
# end up with a routing loop.
|
||||||
|
"allow_remote": "false",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except twisted.internet.error.DNSLookupError as e:
|
||||||
|
logger.warn("HTTP error fetching remote media %s/%s: %r",
|
||||||
|
server_name, media_id, e)
|
||||||
|
raise NotFoundError()
|
||||||
|
|
||||||
|
except HttpResponseException as e:
|
||||||
|
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
||||||
|
server_name, media_id, e.response)
|
||||||
|
if e.code == twisted.web.http.NOT_FOUND:
|
||||||
|
raise SynapseError.from_http_response_exception(e)
|
||||||
|
raise SynapseError(502, "Failed to fetch remote media")
|
||||||
|
|
||||||
|
except SynapseError:
|
||||||
|
logger.exception("Failed to fetch remote media %s/%s",
|
||||||
|
server_name, media_id)
|
||||||
|
raise
|
||||||
|
except NotRetryingDestination:
|
||||||
|
logger.warn("Not retrying destination %r", server_name)
|
||||||
|
raise SynapseError(502, "Failed to fetch remote media")
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to fetch remote media %s/%s",
|
||||||
|
server_name, media_id)
|
||||||
|
raise SynapseError(502, "Failed to fetch remote media")
|
||||||
|
|
||||||
|
yield finish()
|
||||||
|
|
||||||
|
media_type = headers["Content-Type"][0]
|
||||||
|
|
||||||
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
|
content_disposition = headers.get("Content-Disposition", None)
|
||||||
|
if content_disposition:
|
||||||
|
_, params = cgi.parse_header(content_disposition[0],)
|
||||||
|
upload_name = None
|
||||||
|
|
||||||
|
# First check if there is a valid UTF-8 filename
|
||||||
|
upload_name_utf8 = params.get("filename*", None)
|
||||||
|
if upload_name_utf8:
|
||||||
|
if upload_name_utf8.lower().startswith("utf-8''"):
|
||||||
|
upload_name = upload_name_utf8[7:]
|
||||||
|
|
||||||
|
# If there isn't check for an ascii name.
|
||||||
|
if not upload_name:
|
||||||
|
upload_name_ascii = params.get("filename", None)
|
||||||
|
if upload_name_ascii and is_ascii(upload_name_ascii):
|
||||||
|
upload_name = upload_name_ascii
|
||||||
|
|
||||||
|
if upload_name:
|
||||||
|
upload_name = urlparse.unquote(upload_name)
|
||||||
try:
|
try:
|
||||||
length, headers = yield self.client.get_file(
|
upload_name = upload_name.decode("utf-8")
|
||||||
server_name, request_path, output_stream=f,
|
except UnicodeDecodeError:
|
||||||
max_size=self.max_upload_size, args={
|
upload_name = None
|
||||||
# tell the remote server to 404 if it doesn't
|
else:
|
||||||
# recognise the server_name, to make sure we don't
|
upload_name = None
|
||||||
# end up with a routing loop.
|
|
||||||
"allow_remote": "false",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except twisted.internet.error.DNSLookupError as e:
|
|
||||||
logger.warn("HTTP error fetching remote media %s/%s: %r",
|
|
||||||
server_name, media_id, e)
|
|
||||||
raise NotFoundError()
|
|
||||||
|
|
||||||
except HttpResponseException as e:
|
logger.info("Stored remote media in file %r", fname)
|
||||||
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
|
||||||
server_name, media_id, e.response)
|
|
||||||
if e.code == twisted.web.http.NOT_FOUND:
|
|
||||||
raise SynapseError.from_http_response_exception(e)
|
|
||||||
raise SynapseError(502, "Failed to fetch remote media")
|
|
||||||
|
|
||||||
except SynapseError:
|
yield self.store.store_cached_remote_media(
|
||||||
logger.exception("Failed to fetch remote media %s/%s",
|
origin=server_name,
|
||||||
server_name, media_id)
|
media_id=media_id,
|
||||||
raise
|
media_type=media_type,
|
||||||
except NotRetryingDestination:
|
time_now_ms=self.clock.time_msec(),
|
||||||
logger.warn("Not retrying destination %r", server_name)
|
upload_name=upload_name,
|
||||||
raise SynapseError(502, "Failed to fetch remote media")
|
media_length=length,
|
||||||
except Exception:
|
filesystem_id=file_id,
|
||||||
logger.exception("Failed to fetch remote media %s/%s",
|
)
|
||||||
server_name, media_id)
|
|
||||||
raise SynapseError(502, "Failed to fetch remote media")
|
|
||||||
|
|
||||||
yield self.copy_to_backup(fpath)
|
|
||||||
|
|
||||||
media_type = headers["Content-Type"][0]
|
|
||||||
time_now_ms = self.clock.time_msec()
|
|
||||||
|
|
||||||
content_disposition = headers.get("Content-Disposition", None)
|
|
||||||
if content_disposition:
|
|
||||||
_, params = cgi.parse_header(content_disposition[0],)
|
|
||||||
upload_name = None
|
|
||||||
|
|
||||||
# First check if there is a valid UTF-8 filename
|
|
||||||
upload_name_utf8 = params.get("filename*", None)
|
|
||||||
if upload_name_utf8:
|
|
||||||
if upload_name_utf8.lower().startswith("utf-8''"):
|
|
||||||
upload_name = upload_name_utf8[7:]
|
|
||||||
|
|
||||||
# If there isn't check for an ascii name.
|
|
||||||
if not upload_name:
|
|
||||||
upload_name_ascii = params.get("filename", None)
|
|
||||||
if upload_name_ascii and is_ascii(upload_name_ascii):
|
|
||||||
upload_name = upload_name_ascii
|
|
||||||
|
|
||||||
if upload_name:
|
|
||||||
upload_name = urlparse.unquote(upload_name)
|
|
||||||
try:
|
|
||||||
upload_name = upload_name.decode("utf-8")
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
upload_name = None
|
|
||||||
else:
|
|
||||||
upload_name = None
|
|
||||||
|
|
||||||
logger.info("Stored remote media in file %r", fname)
|
|
||||||
|
|
||||||
yield self.store.store_cached_remote_media(
|
|
||||||
origin=server_name,
|
|
||||||
media_id=media_id,
|
|
||||||
media_type=media_type,
|
|
||||||
time_now_ms=self.clock.time_msec(),
|
|
||||||
upload_name=upload_name,
|
|
||||||
media_length=length,
|
|
||||||
filesystem_id=file_id,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
os.remove(fname)
|
|
||||||
raise
|
|
||||||
|
|
||||||
media_info = {
|
media_info = {
|
||||||
"media_type": media_type,
|
"media_type": media_type,
|
||||||
@ -323,7 +438,7 @@ class MediaRepository(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
yield self._generate_thumbnails(
|
yield self._generate_thumbnails(
|
||||||
server_name, media_id, media_info
|
server_name, media_id, file_id, media_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(media_info)
|
defer.returnValue(media_info)
|
||||||
@ -368,11 +483,18 @@ class MediaRepository(object):
|
|||||||
|
|
||||||
if t_byte_source:
|
if t_byte_source:
|
||||||
try:
|
try:
|
||||||
output_path = yield self.write_to_file_and_backup(
|
file_info = FileInfo(
|
||||||
t_byte_source,
|
server_name=None,
|
||||||
self.filepaths.local_media_thumbnail_rel(
|
file_id=media_id,
|
||||||
media_id, t_width, t_height, t_type, t_method
|
thumbnail=True,
|
||||||
)
|
thumbnail_width=t_width,
|
||||||
|
thumbnail_height=t_height,
|
||||||
|
thumbnail_method=t_method,
|
||||||
|
thumbnail_type=t_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_path = yield self.media_storage.store_file(
|
||||||
|
t_byte_source, file_info,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
t_byte_source.close()
|
t_byte_source.close()
|
||||||
@ -400,11 +522,18 @@ class MediaRepository(object):
|
|||||||
|
|
||||||
if t_byte_source:
|
if t_byte_source:
|
||||||
try:
|
try:
|
||||||
output_path = yield self.write_to_file_and_backup(
|
file_info = FileInfo(
|
||||||
t_byte_source,
|
server_name=server_name,
|
||||||
self.filepaths.remote_media_thumbnail_rel(
|
file_id=media_id,
|
||||||
server_name, file_id, t_width, t_height, t_type, t_method
|
thumbnail=True,
|
||||||
)
|
thumbnail_width=t_width,
|
||||||
|
thumbnail_height=t_height,
|
||||||
|
thumbnail_method=t_method,
|
||||||
|
thumbnail_type=t_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_path = yield self.media_storage.store_file(
|
||||||
|
t_byte_source, file_info,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
t_byte_source.close()
|
t_byte_source.close()
|
||||||
@ -421,21 +550,22 @@ class MediaRepository(object):
|
|||||||
defer.returnValue(output_path)
|
defer.returnValue(output_path)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
|
def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
|
||||||
|
url_cache=False):
|
||||||
"""Generate and store thumbnails for an image.
|
"""Generate and store thumbnails for an image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_name(str|None): The server name if remote media, else None if local
|
server_name (str|None): The server name if remote media, else None if local
|
||||||
media_id(str)
|
media_id (str): The media ID of the content. (This is the same as
|
||||||
media_info(dict)
|
the file_id for local content)
|
||||||
url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
|
file_id (str): Local file ID
|
||||||
|
media_type (str): The content type of the file
|
||||||
|
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
|
||||||
used exclusively by the url previewer
|
used exclusively by the url previewer
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict]: Dict with "width" and "height" keys of original image
|
Deferred[dict]: Dict with "width" and "height" keys of original image
|
||||||
"""
|
"""
|
||||||
media_type = media_info["media_type"]
|
|
||||||
file_id = media_info.get("filesystem_id")
|
|
||||||
requirements = self._get_thumbnail_requirements(media_type)
|
requirements = self._get_thumbnail_requirements(media_type)
|
||||||
if not requirements:
|
if not requirements:
|
||||||
return
|
return
|
||||||
@ -472,20 +602,6 @@ class MediaRepository(object):
|
|||||||
|
|
||||||
# Now we generate the thumbnails for each dimension, store it
|
# Now we generate the thumbnails for each dimension, store it
|
||||||
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
|
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
|
||||||
# Work out the correct file name for thumbnail
|
|
||||||
if server_name:
|
|
||||||
file_path = self.filepaths.remote_media_thumbnail_rel(
|
|
||||||
server_name, file_id, t_width, t_height, t_type, t_method
|
|
||||||
)
|
|
||||||
elif url_cache:
|
|
||||||
file_path = self.filepaths.url_cache_thumbnail_rel(
|
|
||||||
media_id, t_width, t_height, t_type, t_method
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
file_path = self.filepaths.local_media_thumbnail_rel(
|
|
||||||
media_id, t_width, t_height, t_type, t_method
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate the thumbnail
|
# Generate the thumbnail
|
||||||
if t_method == "crop":
|
if t_method == "crop":
|
||||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||||
@ -505,9 +621,19 @@ class MediaRepository(object):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Write to disk
|
file_info = FileInfo(
|
||||||
output_path = yield self.write_to_file_and_backup(
|
server_name=server_name,
|
||||||
t_byte_source, file_path,
|
file_id=file_id,
|
||||||
|
thumbnail=True,
|
||||||
|
thumbnail_width=t_width,
|
||||||
|
thumbnail_height=t_height,
|
||||||
|
thumbnail_method=t_method,
|
||||||
|
thumbnail_type=t_type,
|
||||||
|
url_cache=url_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_path = yield self.media_storage.store_file(
|
||||||
|
t_byte_source, file_info,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
t_byte_source.close()
|
t_byte_source.close()
|
||||||
@ -620,7 +746,11 @@ class MediaRepositoryResource(Resource):
|
|||||||
|
|
||||||
self.putChild("upload", UploadResource(hs, media_repo))
|
self.putChild("upload", UploadResource(hs, media_repo))
|
||||||
self.putChild("download", DownloadResource(hs, media_repo))
|
self.putChild("download", DownloadResource(hs, media_repo))
|
||||||
self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
|
self.putChild("thumbnail", ThumbnailResource(
|
||||||
|
hs, media_repo, media_repo.media_storage,
|
||||||
|
))
|
||||||
self.putChild("identicon", IdenticonResource())
|
self.putChild("identicon", IdenticonResource())
|
||||||
if hs.config.url_preview_enabled:
|
if hs.config.url_preview_enabled:
|
||||||
self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
|
self.putChild("preview_url", PreviewUrlResource(
|
||||||
|
hs, media_repo, media_repo.media_storage,
|
||||||
|
))
|
||||||
|
236
synapse/rest/media/v1/media_storage.py
Normal file
236
synapse/rest/media/v1/media_storage.py
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vecotr Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer, threads
|
||||||
|
from twisted.protocols.basic import FileSender
|
||||||
|
|
||||||
|
from ._base import Responder
|
||||||
|
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MediaStorage(object):
|
||||||
|
"""Responsible for storing/fetching files from local sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_media_directory (str): Base path where we store media on disk
|
||||||
|
filepaths (MediaFilePaths)
|
||||||
|
storage_providers ([StorageProvider]): List of StorageProvider that are
|
||||||
|
used to fetch and store files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, local_media_directory, filepaths, storage_providers):
|
||||||
|
self.local_media_directory = local_media_directory
|
||||||
|
self.filepaths = filepaths
|
||||||
|
self.storage_providers = storage_providers
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def store_file(self, source, file_info):
|
||||||
|
"""Write `source` to the on disk media store, and also any other
|
||||||
|
configured storage providers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: A file like object that should be written
|
||||||
|
file_info (FileInfo): Info about the file to store
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[str]: the file path written to in the primary media store
|
||||||
|
"""
|
||||||
|
path = self._file_info_to_path(file_info)
|
||||||
|
fname = os.path.join(self.local_media_directory, path)
|
||||||
|
|
||||||
|
dirname = os.path.dirname(fname)
|
||||||
|
if not os.path.exists(dirname):
|
||||||
|
os.makedirs(dirname)
|
||||||
|
|
||||||
|
# Write to the main repository
|
||||||
|
yield make_deferred_yieldable(threads.deferToThread(
|
||||||
|
_write_file_synchronously, source, fname,
|
||||||
|
))
|
||||||
|
|
||||||
|
defer.returnValue(fname)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def store_into_file(self, file_info):
|
||||||
|
"""Context manager used to get a file like object to write into, as
|
||||||
|
described by file_info.
|
||||||
|
|
||||||
|
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
|
||||||
|
like object that can be written to, fname is the absolute path of file
|
||||||
|
on disk, and finish_cb is a function that returns a Deferred.
|
||||||
|
|
||||||
|
fname can be used to read the contents from after upload, e.g. to
|
||||||
|
generate thumbnails.
|
||||||
|
|
||||||
|
finish_cb must be called and waited on after the file has been
|
||||||
|
successfully been written to. Should not be called if there was an
|
||||||
|
error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_info (FileInfo): Info about the file to store
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
with media_storage.store_into_file(info) as (f, fname, finish_cb):
|
||||||
|
# .. write into f ...
|
||||||
|
yield finish_cb()
|
||||||
|
"""
|
||||||
|
|
||||||
|
path = self._file_info_to_path(file_info)
|
||||||
|
fname = os.path.join(self.local_media_directory, path)
|
||||||
|
|
||||||
|
dirname = os.path.dirname(fname)
|
||||||
|
if not os.path.exists(dirname):
|
||||||
|
os.makedirs(dirname)
|
||||||
|
|
||||||
|
finished_called = [False]
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def finish():
|
||||||
|
for provider in self.storage_providers:
|
||||||
|
yield provider.store_file(path, file_info)
|
||||||
|
|
||||||
|
finished_called[0] = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(fname, "wb") as f:
|
||||||
|
yield f, fname, finish
|
||||||
|
except Exception:
|
||||||
|
t, v, tb = sys.exc_info()
|
||||||
|
try:
|
||||||
|
os.remove(fname)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise t, v, tb
|
||||||
|
|
||||||
|
if not finished_called:
|
||||||
|
raise Exception("Finished callback not called")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def fetch_media(self, file_info):
|
||||||
|
"""Attempts to fetch media described by file_info from the local cache
|
||||||
|
and configured storage providers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_info (FileInfo)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[Responder|None]: Returns a Responder if the file was found,
|
||||||
|
otherwise None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
path = self._file_info_to_path(file_info)
|
||||||
|
local_path = os.path.join(self.local_media_directory, path)
|
||||||
|
if os.path.exists(local_path):
|
||||||
|
defer.returnValue(FileResponder(open(local_path, "rb")))
|
||||||
|
|
||||||
|
for provider in self.storage_providers:
|
||||||
|
res = yield provider.fetch(path, file_info)
|
||||||
|
if res:
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
def _file_info_to_path(self, file_info):
|
||||||
|
"""Converts file_info into a relative path.
|
||||||
|
|
||||||
|
The path is suitable for storing files under a directory, e.g. used to
|
||||||
|
store files on local FS under the base media repository directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_info (FileInfo)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str
|
||||||
|
"""
|
||||||
|
if file_info.url_cache:
|
||||||
|
if file_info.thumbnail:
|
||||||
|
return self.filepaths.url_cache_thumbnail_rel(
|
||||||
|
media_id=file_info.file_id,
|
||||||
|
width=file_info.thumbnail_width,
|
||||||
|
height=file_info.thumbnail_height,
|
||||||
|
content_type=file_info.thumbnail_type,
|
||||||
|
method=file_info.thumbnail_method,
|
||||||
|
)
|
||||||
|
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
|
||||||
|
|
||||||
|
if file_info.server_name:
|
||||||
|
if file_info.thumbnail:
|
||||||
|
return self.filepaths.remote_media_thumbnail_rel(
|
||||||
|
server_name=file_info.server_name,
|
||||||
|
file_id=file_info.file_id,
|
||||||
|
width=file_info.thumbnail_width,
|
||||||
|
height=file_info.thumbnail_height,
|
||||||
|
content_type=file_info.thumbnail_type,
|
||||||
|
method=file_info.thumbnail_method
|
||||||
|
)
|
||||||
|
return self.filepaths.remote_media_filepath_rel(
|
||||||
|
file_info.server_name, file_info.file_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if file_info.thumbnail:
|
||||||
|
return self.filepaths.local_media_thumbnail_rel(
|
||||||
|
media_id=file_info.file_id,
|
||||||
|
width=file_info.thumbnail_width,
|
||||||
|
height=file_info.thumbnail_height,
|
||||||
|
content_type=file_info.thumbnail_type,
|
||||||
|
method=file_info.thumbnail_method
|
||||||
|
)
|
||||||
|
return self.filepaths.local_media_filepath_rel(
|
||||||
|
file_info.file_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_file_synchronously(source, fname):
|
||||||
|
"""Write `source` to the path `fname` synchronously. Should be called
|
||||||
|
from a thread.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: A file like object to be written
|
||||||
|
fname (str): Path to write to
|
||||||
|
"""
|
||||||
|
dirname = os.path.dirname(fname)
|
||||||
|
if not os.path.exists(dirname):
|
||||||
|
os.makedirs(dirname)
|
||||||
|
|
||||||
|
source.seek(0) # Ensure we read from the start of the file
|
||||||
|
with open(fname, "wb") as f:
|
||||||
|
shutil.copyfileobj(source, f)
|
||||||
|
|
||||||
|
|
||||||
|
class FileResponder(Responder):
|
||||||
|
"""Wraps an open file that can be sent to a request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
open_file (file): A file like object to be streamed ot the client,
|
||||||
|
is closed when finished streaming.
|
||||||
|
"""
|
||||||
|
def __init__(self, open_file):
|
||||||
|
self.open_file = open_file
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def write_to_consumer(self, consumer):
|
||||||
|
yield FileSender().beginFileTransfer(self.open_file, consumer)
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.open_file.close()
|
@ -17,6 +17,8 @@ from twisted.web.server import NOT_DONE_YET
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
from ._base import FileInfo
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
SynapseError, Codes,
|
SynapseError, Codes,
|
||||||
)
|
)
|
||||||
@ -49,7 +51,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class PreviewUrlResource(Resource):
|
class PreviewUrlResource(Resource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs, media_repo):
|
def __init__(self, hs, media_repo, media_storage):
|
||||||
Resource.__init__(self)
|
Resource.__init__(self)
|
||||||
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
@ -62,6 +64,7 @@ class PreviewUrlResource(Resource):
|
|||||||
self.client = SpiderHttpClient(hs)
|
self.client = SpiderHttpClient(hs)
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
self.primary_base_path = media_repo.primary_base_path
|
self.primary_base_path = media_repo.primary_base_path
|
||||||
|
self.media_storage = media_storage
|
||||||
|
|
||||||
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
||||||
|
|
||||||
@ -182,8 +185,10 @@ class PreviewUrlResource(Resource):
|
|||||||
logger.debug("got media_info of '%s'" % media_info)
|
logger.debug("got media_info of '%s'" % media_info)
|
||||||
|
|
||||||
if _is_media(media_info['media_type']):
|
if _is_media(media_info['media_type']):
|
||||||
|
file_id = media_info['filesystem_id']
|
||||||
dims = yield self.media_repo._generate_thumbnails(
|
dims = yield self.media_repo._generate_thumbnails(
|
||||||
None, media_info['filesystem_id'], media_info, url_cache=True,
|
None, file_id, file_id, media_info["media_type"],
|
||||||
|
url_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
og = {
|
og = {
|
||||||
@ -228,8 +233,10 @@ class PreviewUrlResource(Resource):
|
|||||||
|
|
||||||
if _is_media(image_info['media_type']):
|
if _is_media(image_info['media_type']):
|
||||||
# TODO: make sure we don't choke on white-on-transparent images
|
# TODO: make sure we don't choke on white-on-transparent images
|
||||||
|
file_id = image_info['filesystem_id']
|
||||||
dims = yield self.media_repo._generate_thumbnails(
|
dims = yield self.media_repo._generate_thumbnails(
|
||||||
None, image_info['filesystem_id'], image_info, url_cache=True,
|
None, file_id, file_id, image_info["media_type"],
|
||||||
|
url_cache=True,
|
||||||
)
|
)
|
||||||
if dims:
|
if dims:
|
||||||
og["og:image:width"] = dims['width']
|
og["og:image:width"] = dims['width']
|
||||||
@ -273,19 +280,21 @@ class PreviewUrlResource(Resource):
|
|||||||
|
|
||||||
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
|
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
|
||||||
|
|
||||||
fpath = self.filepaths.url_cache_filepath_rel(file_id)
|
file_info = FileInfo(
|
||||||
fname = os.path.join(self.primary_base_path, fpath)
|
server_name=None,
|
||||||
self.media_repo._makedirs(fname)
|
file_id=file_id,
|
||||||
|
url_cache=True,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(fname, "wb") as f:
|
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||||
logger.debug("Trying to get url '%s'" % url)
|
logger.debug("Trying to get url '%s'" % url)
|
||||||
length, headers, uri, code = yield self.client.get_file(
|
length, headers, uri, code = yield self.client.get_file(
|
||||||
url, output_stream=f, max_size=self.max_spider_size,
|
url, output_stream=f, max_size=self.max_spider_size,
|
||||||
)
|
)
|
||||||
# FIXME: pass through 404s and other error messages nicely
|
# FIXME: pass through 404s and other error messages nicely
|
||||||
|
|
||||||
yield self.media_repo.copy_to_backup(fpath)
|
yield finish()
|
||||||
|
|
||||||
media_type = headers["Content-Type"][0]
|
media_type = headers["Content-Type"][0]
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
@ -327,7 +336,6 @@ class PreviewUrlResource(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
os.remove(fname)
|
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
500, ("Failed to download content: %s" % e),
|
500, ("Failed to download content: %s" % e),
|
||||||
Codes.UNKNOWN
|
Codes.UNKNOWN
|
||||||
|
140
synapse/rest/media/v1/storage_provider.py
Normal file
140
synapse/rest/media/v1/storage_provider.py
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer, threads
|
||||||
|
|
||||||
|
from .media_storage import FileResponder
|
||||||
|
|
||||||
|
from synapse.config._base import Config
|
||||||
|
from synapse.util.logcontext import preserve_fn
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StorageProvider(object):
|
||||||
|
"""A storage provider is a service that can store uploaded media and
|
||||||
|
retrieve them.
|
||||||
|
"""
|
||||||
|
def store_file(self, path, file_info):
|
||||||
|
"""Store the file described by file_info. The actual contents can be
|
||||||
|
retrieved by reading the file in file_info.upload_path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): Relative path of file in local cache
|
||||||
|
file_info (FileInfo)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def fetch(self, path, file_info):
|
||||||
|
"""Attempt to fetch the file described by file_info and stream it
|
||||||
|
into writer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): Relative path of file in local cache
|
||||||
|
file_info (FileInfo)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred(Responder): Returns a Responder if the provider has the file,
|
||||||
|
otherwise returns None.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StorageProviderWrapper(StorageProvider):
|
||||||
|
"""Wraps a storage provider and provides various config options
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend (StorageProvider)
|
||||||
|
store_local (bool): Whether to store new local files or not.
|
||||||
|
store_synchronous (bool): Whether to wait for file to be successfully
|
||||||
|
uploaded, or todo the upload in the backgroud.
|
||||||
|
store_remote (bool): Whether remote media should be uploaded
|
||||||
|
"""
|
||||||
|
def __init__(self, backend, store_local, store_synchronous, store_remote):
|
||||||
|
self.backend = backend
|
||||||
|
self.store_local = store_local
|
||||||
|
self.store_synchronous = store_synchronous
|
||||||
|
self.store_remote = store_remote
|
||||||
|
|
||||||
|
def store_file(self, path, file_info):
|
||||||
|
if not file_info.server_name and not self.store_local:
|
||||||
|
return defer.succeed(None)
|
||||||
|
|
||||||
|
if file_info.server_name and not self.store_remote:
|
||||||
|
return defer.succeed(None)
|
||||||
|
|
||||||
|
if self.store_synchronous:
|
||||||
|
return self.backend.store_file(path, file_info)
|
||||||
|
else:
|
||||||
|
# TODO: Handle errors.
|
||||||
|
preserve_fn(self.backend.store_file)(path, file_info)
|
||||||
|
return defer.succeed(None)
|
||||||
|
|
||||||
|
def fetch(self, path, file_info):
|
||||||
|
return self.backend.fetch(path, file_info)
|
||||||
|
|
||||||
|
|
||||||
|
class FileStorageProviderBackend(StorageProvider):
|
||||||
|
"""A storage provider that stores files in a directory on a filesystem.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hs (HomeServer)
|
||||||
|
config: The config returned by `parse_config`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hs, config):
|
||||||
|
self.cache_directory = hs.config.media_store_path
|
||||||
|
self.base_directory = config
|
||||||
|
|
||||||
|
def store_file(self, path, file_info):
|
||||||
|
"""See StorageProvider.store_file"""
|
||||||
|
|
||||||
|
primary_fname = os.path.join(self.cache_directory, path)
|
||||||
|
backup_fname = os.path.join(self.base_directory, path)
|
||||||
|
|
||||||
|
dirname = os.path.dirname(backup_fname)
|
||||||
|
if not os.path.exists(dirname):
|
||||||
|
os.makedirs(dirname)
|
||||||
|
|
||||||
|
return threads.deferToThread(
|
||||||
|
shutil.copyfile, primary_fname, backup_fname,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fetch(self, path, file_info):
|
||||||
|
"""See StorageProvider.fetch"""
|
||||||
|
|
||||||
|
backup_fname = os.path.join(self.base_directory, path)
|
||||||
|
if os.path.isfile(backup_fname):
|
||||||
|
return FileResponder(open(backup_fname, "rb"))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_config(config):
|
||||||
|
"""Called on startup to parse config supplied. This should parse
|
||||||
|
the config and raise if there is a problem.
|
||||||
|
|
||||||
|
The returned value is passed into the constructor.
|
||||||
|
|
||||||
|
In this case we only care about a single param, the directory, so let's
|
||||||
|
just pull that out.
|
||||||
|
"""
|
||||||
|
return Config.ensure_directory(config["directory"])
|
@ -14,7 +14,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from ._base import parse_media_id, respond_404, respond_with_file
|
from ._base import (
|
||||||
|
parse_media_id, respond_404, respond_with_file, FileInfo,
|
||||||
|
respond_with_responder,
|
||||||
|
)
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from synapse.http.servlet import parse_string, parse_integer
|
from synapse.http.servlet import parse_string, parse_integer
|
||||||
from synapse.http.server import request_handler, set_cors_headers
|
from synapse.http.server import request_handler, set_cors_headers
|
||||||
@ -30,12 +33,12 @@ logger = logging.getLogger(__name__)
|
|||||||
class ThumbnailResource(Resource):
|
class ThumbnailResource(Resource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs, media_repo):
|
def __init__(self, hs, media_repo, media_storage):
|
||||||
Resource.__init__(self)
|
Resource.__init__(self)
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.filepaths = media_repo.filepaths
|
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
|
self.media_storage = media_storage
|
||||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
@ -64,6 +67,7 @@ class ThumbnailResource(Resource):
|
|||||||
yield self._respond_local_thumbnail(
|
yield self._respond_local_thumbnail(
|
||||||
request, media_id, width, height, method, m_type
|
request, media_id, width, height, method, m_type
|
||||||
)
|
)
|
||||||
|
self.media_repo.mark_recently_accessed(None, media_id)
|
||||||
else:
|
else:
|
||||||
if self.dynamic_thumbnails:
|
if self.dynamic_thumbnails:
|
||||||
yield self._select_or_generate_remote_thumbnail(
|
yield self._select_or_generate_remote_thumbnail(
|
||||||
@ -75,20 +79,20 @@ class ThumbnailResource(Resource):
|
|||||||
request, server_name, media_id,
|
request, server_name, media_id,
|
||||||
width, height, method, m_type
|
width, height, method, m_type
|
||||||
)
|
)
|
||||||
|
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _respond_local_thumbnail(self, request, media_id, width, height,
|
def _respond_local_thumbnail(self, request, media_id, width, height,
|
||||||
method, m_type):
|
method, m_type):
|
||||||
media_info = yield self.store.get_local_media(media_id)
|
media_info = yield self.store.get_local_media(media_id)
|
||||||
|
|
||||||
if not media_info or media_info["quarantined_by"]:
|
if not media_info:
|
||||||
|
respond_404(request)
|
||||||
|
return
|
||||||
|
if media_info["quarantined_by"]:
|
||||||
|
logger.info("Media is quarantined")
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
# if media_info["media_type"] == "image/svg+xml":
|
|
||||||
# file_path = self.filepaths.local_media_filepath(media_id)
|
|
||||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
|
||||||
# return
|
|
||||||
|
|
||||||
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
||||||
|
|
||||||
@ -96,42 +100,39 @@ class ThumbnailResource(Resource):
|
|||||||
thumbnail_info = self._select_thumbnail(
|
thumbnail_info = self._select_thumbnail(
|
||||||
width, height, method, m_type, thumbnail_infos
|
width, height, method, m_type, thumbnail_infos
|
||||||
)
|
)
|
||||||
t_width = thumbnail_info["thumbnail_width"]
|
|
||||||
t_height = thumbnail_info["thumbnail_height"]
|
|
||||||
t_type = thumbnail_info["thumbnail_type"]
|
|
||||||
t_method = thumbnail_info["thumbnail_method"]
|
|
||||||
|
|
||||||
if media_info["url_cache"]:
|
file_info = FileInfo(
|
||||||
# TODO: Check the file still exists, if it doesn't we can redownload
|
server_name=None, file_id=media_id,
|
||||||
# it from the url `media_info["url_cache"]`
|
url_cache=media_info["url_cache"],
|
||||||
file_path = self.filepaths.url_cache_thumbnail(
|
thumbnail=True,
|
||||||
media_id, t_width, t_height, t_type, t_method,
|
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||||
)
|
thumbnail_height=thumbnail_info["thumbnail_height"],
|
||||||
else:
|
thumbnail_type=thumbnail_info["thumbnail_type"],
|
||||||
file_path = self.filepaths.local_media_thumbnail(
|
thumbnail_method=thumbnail_info["thumbnail_method"],
|
||||||
media_id, t_width, t_height, t_type, t_method,
|
|
||||||
)
|
|
||||||
yield respond_with_file(request, t_type, file_path)
|
|
||||||
|
|
||||||
else:
|
|
||||||
yield self._respond_default_thumbnail(
|
|
||||||
request, media_info, width, height, method, m_type,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
t_type = file_info.thumbnail_type
|
||||||
|
t_length = thumbnail_info["thumbnail_length"]
|
||||||
|
|
||||||
|
responder = yield self.media_storage.fetch_media(file_info)
|
||||||
|
yield respond_with_responder(request, responder, t_type, t_length)
|
||||||
|
else:
|
||||||
|
logger.info("Couldn't find any generated thumbnails")
|
||||||
|
respond_404(request)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
|
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
|
||||||
desired_height, desired_method,
|
desired_height, desired_method,
|
||||||
desired_type):
|
desired_type):
|
||||||
media_info = yield self.store.get_local_media(media_id)
|
media_info = yield self.store.get_local_media(media_id)
|
||||||
|
|
||||||
if not media_info or media_info["quarantined_by"]:
|
if not media_info:
|
||||||
|
respond_404(request)
|
||||||
|
return
|
||||||
|
if media_info["quarantined_by"]:
|
||||||
|
logger.info("Media is quarantined")
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
# if media_info["media_type"] == "image/svg+xml":
|
|
||||||
# file_path = self.filepaths.local_media_filepath(media_id)
|
|
||||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
|
||||||
# return
|
|
||||||
|
|
||||||
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
||||||
for info in thumbnail_infos:
|
for info in thumbnail_infos:
|
||||||
@ -141,22 +142,25 @@ class ThumbnailResource(Resource):
|
|||||||
t_type = info["thumbnail_type"] == desired_type
|
t_type = info["thumbnail_type"] == desired_type
|
||||||
|
|
||||||
if t_w and t_h and t_method and t_type:
|
if t_w and t_h and t_method and t_type:
|
||||||
if media_info["url_cache"]:
|
file_info = FileInfo(
|
||||||
# TODO: Check the file still exists, if it doesn't we can redownload
|
server_name=None, file_id=media_id,
|
||||||
# it from the url `media_info["url_cache"]`
|
url_cache=media_info["url_cache"],
|
||||||
file_path = self.filepaths.url_cache_thumbnail(
|
thumbnail=True,
|
||||||
media_id, desired_width, desired_height, desired_type,
|
thumbnail_width=info["thumbnail_width"],
|
||||||
desired_method,
|
thumbnail_height=info["thumbnail_height"],
|
||||||
)
|
thumbnail_type=info["thumbnail_type"],
|
||||||
else:
|
thumbnail_method=info["thumbnail_method"],
|
||||||
file_path = self.filepaths.local_media_thumbnail(
|
)
|
||||||
media_id, desired_width, desired_height, desired_type,
|
|
||||||
desired_method,
|
|
||||||
)
|
|
||||||
yield respond_with_file(request, desired_type, file_path)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug("We don't have a local thumbnail of that size. Generating")
|
t_type = file_info.thumbnail_type
|
||||||
|
t_length = info["thumbnail_length"]
|
||||||
|
|
||||||
|
responder = yield self.media_storage.fetch_media(file_info)
|
||||||
|
if responder:
|
||||||
|
yield respond_with_responder(request, responder, t_type, t_length)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||||
|
|
||||||
# Okay, so we generate one.
|
# Okay, so we generate one.
|
||||||
file_path = yield self.media_repo.generate_local_exact_thumbnail(
|
file_path = yield self.media_repo.generate_local_exact_thumbnail(
|
||||||
@ -166,21 +170,14 @@ class ThumbnailResource(Resource):
|
|||||||
if file_path:
|
if file_path:
|
||||||
yield respond_with_file(request, desired_type, file_path)
|
yield respond_with_file(request, desired_type, file_path)
|
||||||
else:
|
else:
|
||||||
yield self._respond_default_thumbnail(
|
logger.warn("Failed to generate thumbnail")
|
||||||
request, media_info, desired_width, desired_height,
|
respond_404(request)
|
||||||
desired_method, desired_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
|
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
|
||||||
desired_width, desired_height,
|
desired_width, desired_height,
|
||||||
desired_method, desired_type):
|
desired_method, desired_type):
|
||||||
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
|
||||||
|
|
||||||
# if media_info["media_type"] == "image/svg+xml":
|
|
||||||
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
|
|
||||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
|
||||||
# return
|
|
||||||
|
|
||||||
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||||
server_name, media_id,
|
server_name, media_id,
|
||||||
@ -195,14 +192,24 @@ class ThumbnailResource(Resource):
|
|||||||
t_type = info["thumbnail_type"] == desired_type
|
t_type = info["thumbnail_type"] == desired_type
|
||||||
|
|
||||||
if t_w and t_h and t_method and t_type:
|
if t_w and t_h and t_method and t_type:
|
||||||
file_path = self.filepaths.remote_media_thumbnail(
|
file_info = FileInfo(
|
||||||
server_name, file_id, desired_width, desired_height,
|
server_name=server_name, file_id=media_info["filesystem_id"],
|
||||||
desired_type, desired_method,
|
thumbnail=True,
|
||||||
|
thumbnail_width=info["thumbnail_width"],
|
||||||
|
thumbnail_height=info["thumbnail_height"],
|
||||||
|
thumbnail_type=info["thumbnail_type"],
|
||||||
|
thumbnail_method=info["thumbnail_method"],
|
||||||
)
|
)
|
||||||
yield respond_with_file(request, desired_type, file_path)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug("We don't have a local thumbnail of that size. Generating")
|
t_type = file_info.thumbnail_type
|
||||||
|
t_length = info["thumbnail_length"]
|
||||||
|
|
||||||
|
responder = yield self.media_storage.fetch_media(file_info)
|
||||||
|
if responder:
|
||||||
|
yield respond_with_responder(request, responder, t_type, t_length)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||||
|
|
||||||
# Okay, so we generate one.
|
# Okay, so we generate one.
|
||||||
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
|
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
|
||||||
@ -213,22 +220,16 @@ class ThumbnailResource(Resource):
|
|||||||
if file_path:
|
if file_path:
|
||||||
yield respond_with_file(request, desired_type, file_path)
|
yield respond_with_file(request, desired_type, file_path)
|
||||||
else:
|
else:
|
||||||
yield self._respond_default_thumbnail(
|
logger.warn("Failed to generate thumbnail")
|
||||||
request, media_info, desired_width, desired_height,
|
respond_404(request)
|
||||||
desired_method, desired_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
|
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
|
||||||
height, method, m_type):
|
height, method, m_type):
|
||||||
# TODO: Don't download the whole remote file
|
# TODO: Don't download the whole remote file
|
||||||
# We should proxy the thumbnail from the remote server instead.
|
# We should proxy the thumbnail from the remote server instead of
|
||||||
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
# downloading the remote file and generating our own thumbnails.
|
||||||
|
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
|
||||||
# if media_info["media_type"] == "image/svg+xml":
|
|
||||||
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
|
|
||||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
|
||||||
# return
|
|
||||||
|
|
||||||
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||||
server_name, media_id,
|
server_name, media_id,
|
||||||
@ -238,59 +239,23 @@ class ThumbnailResource(Resource):
|
|||||||
thumbnail_info = self._select_thumbnail(
|
thumbnail_info = self._select_thumbnail(
|
||||||
width, height, method, m_type, thumbnail_infos
|
width, height, method, m_type, thumbnail_infos
|
||||||
)
|
)
|
||||||
t_width = thumbnail_info["thumbnail_width"]
|
file_info = FileInfo(
|
||||||
t_height = thumbnail_info["thumbnail_height"]
|
server_name=server_name, file_id=media_info["filesystem_id"],
|
||||||
t_type = thumbnail_info["thumbnail_type"]
|
thumbnail=True,
|
||||||
t_method = thumbnail_info["thumbnail_method"]
|
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||||
file_id = thumbnail_info["filesystem_id"]
|
thumbnail_height=thumbnail_info["thumbnail_height"],
|
||||||
|
thumbnail_type=thumbnail_info["thumbnail_type"],
|
||||||
|
thumbnail_method=thumbnail_info["thumbnail_method"],
|
||||||
|
)
|
||||||
|
|
||||||
|
t_type = file_info.thumbnail_type
|
||||||
t_length = thumbnail_info["thumbnail_length"]
|
t_length = thumbnail_info["thumbnail_length"]
|
||||||
|
|
||||||
file_path = self.filepaths.remote_media_thumbnail(
|
responder = yield self.media_storage.fetch_media(file_info)
|
||||||
server_name, file_id, t_width, t_height, t_type, t_method,
|
yield respond_with_responder(request, responder, t_type, t_length)
|
||||||
)
|
|
||||||
yield respond_with_file(request, t_type, file_path, t_length)
|
|
||||||
else:
|
else:
|
||||||
yield self._respond_default_thumbnail(
|
logger.info("Failed to find any generated thumbnails")
|
||||||
request, media_info, width, height, method, m_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _respond_default_thumbnail(self, request, media_info, width, height,
|
|
||||||
method, m_type):
|
|
||||||
# XXX: how is this meant to work? store.get_default_thumbnails
|
|
||||||
# appears to always return [] so won't this always 404?
|
|
||||||
media_type = media_info["media_type"]
|
|
||||||
top_level_type = media_type.split("/")[0]
|
|
||||||
sub_type = media_type.split("/")[-1].split(";")[0]
|
|
||||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
|
||||||
top_level_type, sub_type,
|
|
||||||
)
|
|
||||||
if not thumbnail_infos:
|
|
||||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
|
||||||
top_level_type, "_default",
|
|
||||||
)
|
|
||||||
if not thumbnail_infos:
|
|
||||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
|
||||||
"_default", "_default",
|
|
||||||
)
|
|
||||||
if not thumbnail_infos:
|
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
|
||||||
|
|
||||||
thumbnail_info = self._select_thumbnail(
|
|
||||||
width, height, "crop", m_type, thumbnail_infos
|
|
||||||
)
|
|
||||||
|
|
||||||
t_width = thumbnail_info["thumbnail_width"]
|
|
||||||
t_height = thumbnail_info["thumbnail_height"]
|
|
||||||
t_type = thumbnail_info["thumbnail_type"]
|
|
||||||
t_method = thumbnail_info["thumbnail_method"]
|
|
||||||
t_length = thumbnail_info["thumbnail_length"]
|
|
||||||
|
|
||||||
file_path = self.filepaths.default_thumbnail(
|
|
||||||
top_level_type, sub_type, t_width, t_height, t_type, t_method,
|
|
||||||
)
|
|
||||||
yield respond_with_file(request, t_type, file_path, t_length)
|
|
||||||
|
|
||||||
def _select_thumbnail(self, desired_width, desired_height, desired_method,
|
def _select_thumbnail(self, desired_width, desired_height, desired_method,
|
||||||
desired_type, thumbnail_infos):
|
desired_type, thumbnail_infos):
|
||||||
|
@ -307,6 +307,23 @@ class HomeServer(object):
|
|||||||
**self.db_config.get("args", {})
|
**self.db_config.get("args", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_db_conn(self, run_new_connection=True):
|
||||||
|
"""Makes a new connection to the database, skipping the db pool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Connection: a connection object implementing the PEP-249 spec
|
||||||
|
"""
|
||||||
|
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||||
|
# not be passed to the database engine.
|
||||||
|
db_params = {
|
||||||
|
k: v for k, v in self.db_config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = self.database_engine.module.connect(**db_params)
|
||||||
|
if run_new_connection:
|
||||||
|
self.database_engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
def build_media_repository_resource(self):
|
def build_media_repository_resource(self):
|
||||||
# build the media repo resource. This indirects through the HomeServer
|
# build the media repo resource. This indirects through the HomeServer
|
||||||
# to ensure that we only have a single instance of
|
# to ensure that we only have a single instance of
|
||||||
|
@ -146,8 +146,20 @@ class StateHandler(object):
|
|||||||
defer.returnValue(state)
|
defer.returnValue(state)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_current_state_ids(self, room_id, event_type=None, state_key="",
|
def get_current_state_ids(self, room_id, latest_event_ids=None):
|
||||||
latest_event_ids=None):
|
"""Get the current state, or the state at a set of events, for a room
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str):
|
||||||
|
|
||||||
|
latest_event_ids (iterable[str]|None): if given, the forward
|
||||||
|
extremities to resolve. If None, we look them up from the
|
||||||
|
database (via a cache)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[dict[(str, str), str)]]: the state dict, mapping from
|
||||||
|
(event_type, state_key) -> event_id
|
||||||
|
"""
|
||||||
if not latest_event_ids:
|
if not latest_event_ids:
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
|
|
||||||
@ -155,10 +167,6 @@ class StateHandler(object):
|
|||||||
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
state = ret.state
|
state = ret.state
|
||||||
|
|
||||||
if event_type:
|
|
||||||
defer.returnValue(state.get((event_type, state_key)))
|
|
||||||
return
|
|
||||||
|
|
||||||
defer.returnValue(state)
|
defer.returnValue(state)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -341,7 +349,7 @@ class StateHandler(object):
|
|||||||
if conflicted_state:
|
if conflicted_state:
|
||||||
logger.info("Resolving conflicted state for %r", room_id)
|
logger.info("Resolving conflicted state for %r", room_id)
|
||||||
with Measure(self.clock, "state._resolve_events"):
|
with Measure(self.clock, "state._resolve_events"):
|
||||||
new_state = yield resolve_events(
|
new_state = yield resolve_events_with_factory(
|
||||||
state_groups_ids.values(),
|
state_groups_ids.values(),
|
||||||
state_map_factory=lambda ev_ids: self.store.get_events(
|
state_map_factory=lambda ev_ids: self.store.get_events(
|
||||||
ev_ids, get_prev_content=False, check_redacted=False,
|
ev_ids, get_prev_content=False, check_redacted=False,
|
||||||
@ -404,7 +412,7 @@ class StateHandler(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
with Measure(self.clock, "state._resolve_events"):
|
with Measure(self.clock, "state._resolve_events"):
|
||||||
new_state = resolve_events(state_set_ids, state_map)
|
new_state = resolve_events_with_state_map(state_set_ids, state_map)
|
||||||
|
|
||||||
new_state = {
|
new_state = {
|
||||||
key: state_map[ev_id] for key, ev_id in new_state.items()
|
key: state_map[ev_id] for key, ev_id in new_state.items()
|
||||||
@ -420,19 +428,17 @@ def _ordered_events(events):
|
|||||||
return sorted(events, key=key_func)
|
return sorted(events, key=key_func)
|
||||||
|
|
||||||
|
|
||||||
def resolve_events(state_sets, state_map_factory):
|
def resolve_events_with_state_map(state_sets, state_map):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||||
which are the different state groups to resolve.
|
which are the different state groups to resolve.
|
||||||
state_map_factory(dict|callable): If callable, then will be called
|
state_map(dict): a dict from event_id to event, for all events in
|
||||||
with a list of event_ids that are needed, and should return with
|
state_sets.
|
||||||
a Deferred of dict of event_id to event. Otherwise, should be
|
|
||||||
a dict from event_id to event of all events in state_sets.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
dict[(str, str), synapse.events.FrozenEvent] is a map from
|
dict[(str, str), synapse.events.FrozenEvent]:
|
||||||
(type, state_key) to event.
|
a map from (type, state_key) to event.
|
||||||
"""
|
"""
|
||||||
if len(state_sets) == 1:
|
if len(state_sets) == 1:
|
||||||
return state_sets[0]
|
return state_sets[0]
|
||||||
@ -441,13 +447,6 @@ def resolve_events(state_sets, state_map_factory):
|
|||||||
state_sets,
|
state_sets,
|
||||||
)
|
)
|
||||||
|
|
||||||
if callable(state_map_factory):
|
|
||||||
return _resolve_with_state_fac(
|
|
||||||
unconflicted_state, conflicted_state, state_map_factory
|
|
||||||
)
|
|
||||||
|
|
||||||
state_map = state_map_factory
|
|
||||||
|
|
||||||
auth_events = _create_auth_events_from_maps(
|
auth_events = _create_auth_events_from_maps(
|
||||||
unconflicted_state, conflicted_state, state_map
|
unconflicted_state, conflicted_state, state_map
|
||||||
)
|
)
|
||||||
@ -491,8 +490,26 @@ def _seperate(state_sets):
|
|||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _resolve_with_state_fac(unconflicted_state, conflicted_state,
|
def resolve_events_with_factory(state_sets, state_map_factory):
|
||||||
state_map_factory):
|
"""
|
||||||
|
Args:
|
||||||
|
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||||
|
which are the different state groups to resolve.
|
||||||
|
state_map_factory(func): will be called
|
||||||
|
with a list of event_ids that are needed, and should return with
|
||||||
|
a Deferred of dict of event_id to event.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
|
||||||
|
a map from (type, state_key) to event.
|
||||||
|
"""
|
||||||
|
if len(state_sets) == 1:
|
||||||
|
defer.returnValue(state_sets[0])
|
||||||
|
|
||||||
|
unconflicted_state, conflicted_state = _seperate(
|
||||||
|
state_sets,
|
||||||
|
)
|
||||||
|
|
||||||
needed_events = set(
|
needed_events = set(
|
||||||
event_id
|
event_id
|
||||||
for event_ids in conflicted_state.itervalues()
|
for event_ids in conflicted_state.itervalues()
|
||||||
|
@ -291,33 +291,33 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def runInteraction(self, desc, func, *args, **kwargs):
|
def runInteraction(self, desc, func, *args, **kwargs):
|
||||||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
"""Starts a transaction on the database and runs a given function
|
||||||
current_context = LoggingContext.current_context()
|
|
||||||
|
|
||||||
start_time = time.time() * 1000
|
Arguments:
|
||||||
|
desc (str): description of the transaction, for logging and metrics
|
||||||
|
func (func): callback function, which will be called with a
|
||||||
|
database transaction (twisted.enterprise.adbapi.Transaction) as
|
||||||
|
its first argument, followed by `args` and `kwargs`.
|
||||||
|
|
||||||
|
args (list): positional args to pass to `func`
|
||||||
|
kwargs (dict): named args to pass to `func`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: The result of func
|
||||||
|
"""
|
||||||
|
current_context = LoggingContext.current_context()
|
||||||
|
|
||||||
after_callbacks = []
|
after_callbacks = []
|
||||||
final_callbacks = []
|
final_callbacks = []
|
||||||
|
|
||||||
def inner_func(conn, *args, **kwargs):
|
def inner_func(conn, *args, **kwargs):
|
||||||
with LoggingContext("runInteraction") as context:
|
return self._new_transaction(
|
||||||
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
conn, desc, after_callbacks, final_callbacks, current_context,
|
||||||
|
func, *args, **kwargs
|
||||||
if self.database_engine.is_connection_closed(conn):
|
)
|
||||||
logger.debug("Reconnecting closed database connection")
|
|
||||||
conn.reconnect()
|
|
||||||
|
|
||||||
current_context.copy_to(context)
|
|
||||||
return self._new_transaction(
|
|
||||||
conn, desc, after_callbacks, final_callbacks, current_context,
|
|
||||||
func, *args, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with PreserveLoggingContext():
|
result = yield self.runWithConnection(inner_func, *args, **kwargs)
|
||||||
result = yield self._db_pool.runWithConnection(
|
|
||||||
inner_func, *args, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||||
after_callback(*after_args, **after_kwargs)
|
after_callback(*after_args, **after_kwargs)
|
||||||
@ -329,14 +329,27 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def runWithConnection(self, func, *args, **kwargs):
|
def runWithConnection(self, func, *args, **kwargs):
|
||||||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
"""Wraps the .runWithConnection() method on the underlying db_pool.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
func (func): callback function, which will be called with a
|
||||||
|
database connection (twisted.enterprise.adbapi.Connection) as
|
||||||
|
its first argument, followed by `args` and `kwargs`.
|
||||||
|
args (list): positional args to pass to `func`
|
||||||
|
kwargs (dict): named args to pass to `func`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred: The result of func
|
||||||
|
"""
|
||||||
current_context = LoggingContext.current_context()
|
current_context = LoggingContext.current_context()
|
||||||
|
|
||||||
start_time = time.time() * 1000
|
start_time = time.time() * 1000
|
||||||
|
|
||||||
def inner_func(conn, *args, **kwargs):
|
def inner_func(conn, *args, **kwargs):
|
||||||
with LoggingContext("runWithConnection") as context:
|
with LoggingContext("runWithConnection") as context:
|
||||||
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
sched_duration_ms = time.time() * 1000 - start_time
|
||||||
|
sql_scheduling_timer.inc_by(sched_duration_ms)
|
||||||
|
current_context.add_database_scheduled(sched_duration_ms)
|
||||||
|
|
||||||
if self.database_engine.is_connection_closed(conn):
|
if self.database_engine.is_connection_closed(conn):
|
||||||
logger.debug("Reconnecting closed database connection")
|
logger.debug("Reconnecting closed database connection")
|
||||||
|
@ -27,7 +27,7 @@ from synapse.util.logutils import log_function
|
|||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.state import resolve_events
|
from synapse.state import resolve_events_with_factory
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
|
|
||||||
@ -110,7 +110,7 @@ class _EventPeristenceQueue(object):
|
|||||||
end_item.events_and_contexts.extend(events_and_contexts)
|
end_item.events_and_contexts.extend(events_and_contexts)
|
||||||
return end_item.deferred.observe()
|
return end_item.deferred.observe()
|
||||||
|
|
||||||
deferred = ObservableDeferred(defer.Deferred())
|
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
|
||||||
|
|
||||||
queue.append(self._EventPersistQueueItem(
|
queue.append(self._EventPersistQueueItem(
|
||||||
events_and_contexts=events_and_contexts,
|
events_and_contexts=events_and_contexts,
|
||||||
@ -146,18 +146,25 @@ class _EventPeristenceQueue(object):
|
|||||||
try:
|
try:
|
||||||
queue = self._get_drainining_queue(room_id)
|
queue = self._get_drainining_queue(room_id)
|
||||||
for item in queue:
|
for item in queue:
|
||||||
|
# handle_queue_loop runs in the sentinel logcontext, so
|
||||||
|
# there is no need to preserve_fn when running the
|
||||||
|
# callbacks on the deferred.
|
||||||
try:
|
try:
|
||||||
ret = yield per_item_callback(item)
|
ret = yield per_item_callback(item)
|
||||||
item.deferred.callback(ret)
|
item.deferred.callback(ret)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
item.deferred.errback(e)
|
item.deferred.errback()
|
||||||
finally:
|
finally:
|
||||||
queue = self._event_persist_queues.pop(room_id, None)
|
queue = self._event_persist_queues.pop(room_id, None)
|
||||||
if queue:
|
if queue:
|
||||||
self._event_persist_queues[room_id] = queue
|
self._event_persist_queues[room_id] = queue
|
||||||
self._currently_persisting_rooms.discard(room_id)
|
self._currently_persisting_rooms.discard(room_id)
|
||||||
|
|
||||||
preserve_fn(handle_queue_loop)()
|
# set handle_queue_loop off on the background. We don't want to
|
||||||
|
# attribute work done in it to the current request, so we drop the
|
||||||
|
# logcontext altogether.
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
handle_queue_loop()
|
||||||
|
|
||||||
def _get_drainining_queue(self, room_id):
|
def _get_drainining_queue(self, room_id):
|
||||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||||
@ -528,6 +535,12 @@ class EventsStore(SQLBaseStore):
|
|||||||
# the events we have yet to persist, so we need a slightly more
|
# the events we have yet to persist, so we need a slightly more
|
||||||
# complicated event lookup function than simply looking the events
|
# complicated event lookup function than simply looking the events
|
||||||
# up in the db.
|
# up in the db.
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Resolving state for %s with %i state sets",
|
||||||
|
room_id, len(state_sets),
|
||||||
|
)
|
||||||
|
|
||||||
events_map = {ev.event_id: ev for ev, _ in events_context}
|
events_map = {ev.event_id: ev for ev, _ in events_context}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -550,7 +563,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
to_return.update(evs)
|
to_return.update(evs)
|
||||||
defer.returnValue(to_return)
|
defer.returnValue(to_return)
|
||||||
|
|
||||||
current_state = yield resolve_events(
|
current_state = yield resolve_events_with_factory(
|
||||||
state_sets,
|
state_sets,
|
||||||
state_map_factory=get_events,
|
state_map_factory=get_events,
|
||||||
)
|
)
|
||||||
|
@ -29,9 +29,6 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
|||||||
where_clause='url_cache IS NOT NULL',
|
where_clause='url_cache IS NOT NULL',
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_default_thumbnails(self, top_level_type, sub_type):
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_local_media(self, media_id):
|
def get_local_media(self, media_id):
|
||||||
"""Get the metadata for a local piece of media
|
"""Get the metadata for a local piece of media
|
||||||
Returns:
|
Returns:
|
||||||
@ -176,7 +173,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
|||||||
desc="store_cached_remote_media",
|
desc="store_cached_remote_media",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_cached_last_access_time(self, origin_id_tuples, time_ts):
|
def update_cached_last_access_time(self, local_media, remote_media, time_ms):
|
||||||
|
"""Updates the last access time of the given media
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_media (iterable[str]): Set of media_ids
|
||||||
|
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
|
||||||
|
time_ms: Current time in milliseconds
|
||||||
|
"""
|
||||||
def update_cache_txn(txn):
|
def update_cache_txn(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"UPDATE remote_media_cache SET last_access_ts = ?"
|
"UPDATE remote_media_cache SET last_access_ts = ?"
|
||||||
@ -184,8 +188,18 @@ class MediaRepositoryStore(BackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
txn.executemany(sql, (
|
txn.executemany(sql, (
|
||||||
(time_ts, media_origin, media_id)
|
(time_ms, media_origin, media_id)
|
||||||
for media_origin, media_id in origin_id_tuples
|
for media_origin, media_id in remote_media
|
||||||
|
))
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"UPDATE local_media_repository SET last_access_ts = ?"
|
||||||
|
" WHERE media_id = ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.executemany(sql, (
|
||||||
|
(time_ms, media_id)
|
||||||
|
for media_id in local_media
|
||||||
))
|
))
|
||||||
|
|
||||||
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
|
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
|
||||||
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
SCHEMA_VERSION = 46
|
SCHEMA_VERSION = 47
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
@ -591,7 +591,7 @@ class RoomStore(SQLBaseStore):
|
|||||||
"""
|
"""
|
||||||
UPDATE remote_media_cache
|
UPDATE remote_media_cache
|
||||||
SET quarantined_by = ?
|
SET quarantined_by = ?
|
||||||
WHERE media_origin AND media_id = ?
|
WHERE media_origin = ? AND media_id = ?
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
(quarantined_by, origin, media_id)
|
(quarantined_by, origin, media_id)
|
||||||
|
16
synapse/storage/schema/delta/47/last_access_media.sql
Normal file
16
synapse/storage/schema/delta/47/last_access_media.sql
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
/* Copyright 2018 New Vector Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
|
@ -641,8 +641,12 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if self.hs.config.user_directory_search_all_users:
|
if self.hs.config.user_directory_search_all_users:
|
||||||
join_clause = ""
|
# make s.user_id null to keep the ordering algorithm happy
|
||||||
where_clause = "?<>''" # naughty hack to keep the same number of binds
|
join_clause = """
|
||||||
|
CROSS JOIN (SELECT NULL as user_id) AS s
|
||||||
|
"""
|
||||||
|
join_args = ()
|
||||||
|
where_clause = "1=1"
|
||||||
else:
|
else:
|
||||||
join_clause = """
|
join_clause = """
|
||||||
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
||||||
@ -651,6 +655,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
WHERE user_id = ? AND share_private
|
WHERE user_id = ? AND share_private
|
||||||
) AS s USING (user_id)
|
) AS s USING (user_id)
|
||||||
"""
|
"""
|
||||||
|
join_args = (user_id,)
|
||||||
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
|
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
@ -692,7 +697,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
avatar_url IS NULL
|
avatar_url IS NULL
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
""" % (join_clause, where_clause)
|
""" % (join_clause, where_clause)
|
||||||
args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
|
args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
|
||||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
search_query = _parse_query_sqlite(search_term)
|
search_query = _parse_query_sqlite(search_term)
|
||||||
|
|
||||||
@ -710,7 +715,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||||||
avatar_url IS NULL
|
avatar_url IS NULL
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
""" % (join_clause, where_clause)
|
""" % (join_clause, where_clause)
|
||||||
args = (user_id, search_query, limit + 1)
|
args = join_args + (search_query, limit + 1)
|
||||||
else:
|
else:
|
||||||
# This should be unreachable.
|
# This should be unreachable.
|
||||||
raise Exception("Unrecognized database engine")
|
raise Exception("Unrecognized database engine")
|
||||||
|
139
synapse/util/file_consumer.py
Normal file
139
synapse/util/file_consumer.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import threads, reactor
|
||||||
|
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
|
|
||||||
|
import Queue
|
||||||
|
|
||||||
|
|
||||||
|
class BackgroundFileConsumer(object):
|
||||||
|
"""A consumer that writes to a file like object. Supports both push
|
||||||
|
and pull producers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_obj (file): The file like object to write to. Closed when
|
||||||
|
finished.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# For PushProducers pause if we have this many unwritten slices
|
||||||
|
_PAUSE_ON_QUEUE_SIZE = 5
|
||||||
|
# And resume once the size of the queue is less than this
|
||||||
|
_RESUME_ON_QUEUE_SIZE = 2
|
||||||
|
|
||||||
|
def __init__(self, file_obj):
|
||||||
|
self._file_obj = file_obj
|
||||||
|
|
||||||
|
# Producer we're registered with
|
||||||
|
self._producer = None
|
||||||
|
|
||||||
|
# True if PushProducer, false if PullProducer
|
||||||
|
self.streaming = False
|
||||||
|
|
||||||
|
# For PushProducers, indicates whether we've paused the producer and
|
||||||
|
# need to call resumeProducing before we get more data.
|
||||||
|
self._paused_producer = False
|
||||||
|
|
||||||
|
# Queue of slices of bytes to be written. When producer calls
|
||||||
|
# unregister a final None is sent.
|
||||||
|
self._bytes_queue = Queue.Queue()
|
||||||
|
|
||||||
|
# Deferred that is resolved when finished writing
|
||||||
|
self._finished_deferred = None
|
||||||
|
|
||||||
|
# If the _writer thread throws an exception it gets stored here.
|
||||||
|
self._write_exception = None
|
||||||
|
|
||||||
|
def registerProducer(self, producer, streaming):
|
||||||
|
"""Part of IConsumer interface
|
||||||
|
|
||||||
|
Args:
|
||||||
|
producer (IProducer)
|
||||||
|
streaming (bool): True if push based producer, False if pull
|
||||||
|
based.
|
||||||
|
"""
|
||||||
|
if self._producer:
|
||||||
|
raise Exception("registerProducer called twice")
|
||||||
|
|
||||||
|
self._producer = producer
|
||||||
|
self.streaming = streaming
|
||||||
|
self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
|
||||||
|
if not streaming:
|
||||||
|
self._producer.resumeProducing()
|
||||||
|
|
||||||
|
def unregisterProducer(self):
|
||||||
|
"""Part of IProducer interface
|
||||||
|
"""
|
||||||
|
self._producer = None
|
||||||
|
if not self._finished_deferred.called:
|
||||||
|
self._bytes_queue.put_nowait(None)
|
||||||
|
|
||||||
|
def write(self, bytes):
|
||||||
|
"""Part of IProducer interface
|
||||||
|
"""
|
||||||
|
if self._write_exception:
|
||||||
|
raise self._write_exception
|
||||||
|
|
||||||
|
if self._finished_deferred.called:
|
||||||
|
raise Exception("consumer has closed")
|
||||||
|
|
||||||
|
self._bytes_queue.put_nowait(bytes)
|
||||||
|
|
||||||
|
# If this is a PushProducer and the queue is getting behind
|
||||||
|
# then we pause the producer.
|
||||||
|
if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
|
||||||
|
self._paused_producer = True
|
||||||
|
self._producer.pauseProducing()
|
||||||
|
|
||||||
|
def _writer(self):
|
||||||
|
"""This is run in a background thread to write to the file.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
while self._producer or not self._bytes_queue.empty():
|
||||||
|
# If we've paused the producer check if we should resume the
|
||||||
|
# producer.
|
||||||
|
if self._producer and self._paused_producer:
|
||||||
|
if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
|
||||||
|
reactor.callFromThread(self._resume_paused_producer)
|
||||||
|
|
||||||
|
bytes = self._bytes_queue.get()
|
||||||
|
|
||||||
|
# If we get a None (or empty list) then that's a signal used
|
||||||
|
# to indicate we should check if we should stop.
|
||||||
|
if bytes:
|
||||||
|
self._file_obj.write(bytes)
|
||||||
|
|
||||||
|
# If its a pull producer then we need to explicitly ask for
|
||||||
|
# more stuff.
|
||||||
|
if not self.streaming and self._producer:
|
||||||
|
reactor.callFromThread(self._producer.resumeProducing)
|
||||||
|
except Exception as e:
|
||||||
|
self._write_exception = e
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._file_obj.close()
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
"""Returns a deferred that resolves when finished writing to file
|
||||||
|
"""
|
||||||
|
return make_deferred_yieldable(self._finished_deferred)
|
||||||
|
|
||||||
|
def _resume_paused_producer(self):
|
||||||
|
"""Gets called if we should resume producing after being paused
|
||||||
|
"""
|
||||||
|
if self._paused_producer and self._producer:
|
||||||
|
self._paused_producer = False
|
||||||
|
self._producer.resumeProducing()
|
@ -52,13 +52,17 @@ except Exception:
|
|||||||
class LoggingContext(object):
|
class LoggingContext(object):
|
||||||
"""Additional context for log formatting. Contexts are scoped within a
|
"""Additional context for log formatting. Contexts are scoped within a
|
||||||
"with" block.
|
"with" block.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): Name for the context for debugging.
|
name (str): Name for the context for debugging.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"previous_context", "name", "usage_start", "usage_end", "main_thread",
|
"previous_context", "name", "ru_stime", "ru_utime",
|
||||||
"__dict__", "tag", "alive",
|
"db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
|
||||||
|
"usage_start", "usage_end",
|
||||||
|
"main_thread", "alive",
|
||||||
|
"request", "tag",
|
||||||
]
|
]
|
||||||
|
|
||||||
thread_local = threading.local()
|
thread_local = threading.local()
|
||||||
@ -83,6 +87,9 @@ class LoggingContext(object):
|
|||||||
def add_database_transaction(self, duration_ms):
|
def add_database_transaction(self, duration_ms):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def add_database_scheduled(self, sched_ms):
|
||||||
|
pass
|
||||||
|
|
||||||
def __nonzero__(self):
|
def __nonzero__(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -94,9 +101,17 @@ class LoggingContext(object):
|
|||||||
self.ru_stime = 0.
|
self.ru_stime = 0.
|
||||||
self.ru_utime = 0.
|
self.ru_utime = 0.
|
||||||
self.db_txn_count = 0
|
self.db_txn_count = 0
|
||||||
self.db_txn_duration = 0.
|
|
||||||
|
# ms spent waiting for db txns, excluding scheduling time
|
||||||
|
self.db_txn_duration_ms = 0
|
||||||
|
|
||||||
|
# ms spent waiting for db txns to be scheduled
|
||||||
|
self.db_sched_duration_ms = 0
|
||||||
|
|
||||||
self.usage_start = None
|
self.usage_start = None
|
||||||
|
self.usage_end = None
|
||||||
self.main_thread = threading.current_thread()
|
self.main_thread = threading.current_thread()
|
||||||
|
self.request = None
|
||||||
self.tag = ""
|
self.tag = ""
|
||||||
self.alive = True
|
self.alive = True
|
||||||
|
|
||||||
@ -105,7 +120,11 @@ class LoggingContext(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def current_context(cls):
|
def current_context(cls):
|
||||||
"""Get the current logging context from thread local storage"""
|
"""Get the current logging context from thread local storage
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LoggingContext: the current logging context
|
||||||
|
"""
|
||||||
return getattr(cls.thread_local, "current_context", cls.sentinel)
|
return getattr(cls.thread_local, "current_context", cls.sentinel)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -155,11 +174,13 @@ class LoggingContext(object):
|
|||||||
self.alive = False
|
self.alive = False
|
||||||
|
|
||||||
def copy_to(self, record):
|
def copy_to(self, record):
|
||||||
"""Copy fields from this context to the record"""
|
"""Copy logging fields from this context to a log record or
|
||||||
for key, value in self.__dict__.items():
|
another LoggingContext
|
||||||
setattr(record, key, value)
|
"""
|
||||||
|
|
||||||
record.ru_utime, record.ru_stime = self.get_resource_usage()
|
# 'request' is the only field we currently use in the logger, so that's
|
||||||
|
# all we need to copy
|
||||||
|
record.request = self.request
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
if threading.current_thread() is not self.main_thread:
|
if threading.current_thread() is not self.main_thread:
|
||||||
@ -194,7 +215,16 @@ class LoggingContext(object):
|
|||||||
|
|
||||||
def add_database_transaction(self, duration_ms):
|
def add_database_transaction(self, duration_ms):
|
||||||
self.db_txn_count += 1
|
self.db_txn_count += 1
|
||||||
self.db_txn_duration += duration_ms / 1000.
|
self.db_txn_duration_ms += duration_ms
|
||||||
|
|
||||||
|
def add_database_scheduled(self, sched_ms):
|
||||||
|
"""Record a use of the database pool
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sched_ms (int): number of milliseconds it took us to get a
|
||||||
|
connection
|
||||||
|
"""
|
||||||
|
self.db_sched_duration_ms += sched_ms
|
||||||
|
|
||||||
|
|
||||||
class LoggingContextFilter(logging.Filter):
|
class LoggingContextFilter(logging.Filter):
|
||||||
|
@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
block_timer = metrics.register_distribution(
|
# total number of times we have hit this block
|
||||||
"block_timer",
|
block_counter = metrics.register_counter(
|
||||||
labels=["block_name"]
|
"block_count",
|
||||||
|
labels=["block_name"],
|
||||||
|
alternative_names=(
|
||||||
|
# the following are all deprecated aliases for the same metric
|
||||||
|
metrics.name_prefix + x for x in (
|
||||||
|
"_block_timer:count",
|
||||||
|
"_block_ru_utime:count",
|
||||||
|
"_block_ru_stime:count",
|
||||||
|
"_block_db_txn_count:count",
|
||||||
|
"_block_db_txn_duration:count",
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
block_ru_utime = metrics.register_distribution(
|
block_timer = metrics.register_counter(
|
||||||
"block_ru_utime", labels=["block_name"]
|
"block_time_seconds",
|
||||||
|
labels=["block_name"],
|
||||||
|
alternative_names=(
|
||||||
|
metrics.name_prefix + "_block_timer:total",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
block_ru_stime = metrics.register_distribution(
|
block_ru_utime = metrics.register_counter(
|
||||||
"block_ru_stime", labels=["block_name"]
|
"block_ru_utime_seconds", labels=["block_name"],
|
||||||
|
alternative_names=(
|
||||||
|
metrics.name_prefix + "_block_ru_utime:total",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
block_db_txn_count = metrics.register_distribution(
|
block_ru_stime = metrics.register_counter(
|
||||||
"block_db_txn_count", labels=["block_name"]
|
"block_ru_stime_seconds", labels=["block_name"],
|
||||||
|
alternative_names=(
|
||||||
|
metrics.name_prefix + "_block_ru_stime:total",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
block_db_txn_duration = metrics.register_distribution(
|
block_db_txn_count = metrics.register_counter(
|
||||||
"block_db_txn_duration", labels=["block_name"]
|
"block_db_txn_count", labels=["block_name"],
|
||||||
|
alternative_names=(
|
||||||
|
metrics.name_prefix + "_block_db_txn_count:total",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# seconds spent waiting for db txns, excluding scheduling time, in this block
|
||||||
|
block_db_txn_duration = metrics.register_counter(
|
||||||
|
"block_db_txn_duration_seconds", labels=["block_name"],
|
||||||
|
alternative_names=(
|
||||||
|
metrics.name_prefix + "_block_db_txn_duration:total",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# seconds spent waiting for a db connection, in this block
|
||||||
|
block_db_sched_duration = metrics.register_counter(
|
||||||
|
"block_db_sched_duration_seconds", labels=["block_name"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -64,7 +101,9 @@ def measure_func(name):
|
|||||||
class Measure(object):
|
class Measure(object):
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"clock", "name", "start_context", "start", "new_context", "ru_utime",
|
"clock", "name", "start_context", "start", "new_context", "ru_utime",
|
||||||
"ru_stime", "db_txn_count", "db_txn_duration", "created_context"
|
"ru_stime",
|
||||||
|
"db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
|
||||||
|
"created_context",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, clock, name):
|
def __init__(self, clock, name):
|
||||||
@ -84,13 +123,16 @@ class Measure(object):
|
|||||||
|
|
||||||
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
|
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
|
||||||
self.db_txn_count = self.start_context.db_txn_count
|
self.db_txn_count = self.start_context.db_txn_count
|
||||||
self.db_txn_duration = self.start_context.db_txn_duration
|
self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
|
||||||
|
self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
if isinstance(exc_type, Exception) or not self.start_context:
|
if isinstance(exc_type, Exception) or not self.start_context:
|
||||||
return
|
return
|
||||||
|
|
||||||
duration = self.clock.time_msec() - self.start
|
duration = self.clock.time_msec() - self.start
|
||||||
|
|
||||||
|
block_counter.inc(self.name)
|
||||||
block_timer.inc_by(duration, self.name)
|
block_timer.inc_by(duration, self.name)
|
||||||
|
|
||||||
context = LoggingContext.current_context()
|
context = LoggingContext.current_context()
|
||||||
@ -114,7 +156,12 @@ class Measure(object):
|
|||||||
context.db_txn_count - self.db_txn_count, self.name
|
context.db_txn_count - self.db_txn_count, self.name
|
||||||
)
|
)
|
||||||
block_db_txn_duration.inc_by(
|
block_db_txn_duration.inc_by(
|
||||||
context.db_txn_duration - self.db_txn_duration, self.name
|
(context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
|
||||||
|
self.name
|
||||||
|
)
|
||||||
|
block_db_sched_duration.inc_by(
|
||||||
|
(context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
|
||||||
|
self.name
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.created_context:
|
if self.created_context:
|
||||||
|
@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class NotRetryingDestination(Exception):
|
class NotRetryingDestination(Exception):
|
||||||
def __init__(self, retry_last_ts, retry_interval, destination):
|
def __init__(self, retry_last_ts, retry_interval, destination):
|
||||||
|
"""Raised by the limiter (and federation client) to indicate that we are
|
||||||
|
are deliberately not attempting to contact a given server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retry_last_ts (int): the unix ts in milliseconds of our last attempt
|
||||||
|
to contact the server. 0 indicates that the last attempt was
|
||||||
|
successful or that we've never actually attempted to connect.
|
||||||
|
retry_interval (int): the time in milliseconds to wait until the next
|
||||||
|
attempt.
|
||||||
|
destination (str): the domain in question
|
||||||
|
"""
|
||||||
|
|
||||||
msg = "Not retrying server %s." % (destination,)
|
msg = "Not retrying server %s." % (destination,)
|
||||||
super(NotRetryingDestination, self).__init__(msg)
|
super(NotRetryingDestination, self).__init__(msg)
|
||||||
|
|
||||||
|
48
synapse/util/threepids.py
Normal file
48
synapse/util/threepids.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_3pid_allowed(hs, medium, address):
|
||||||
|
"""Checks whether a given format of 3PID is allowed to be used on this HS
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
medium (str): 3pid medium - e.g. email, msisdn
|
||||||
|
address (str): address within that medium (e.g. "wotan@matrix.org")
|
||||||
|
msisdns need to first have been canonicalised
|
||||||
|
Returns:
|
||||||
|
bool: whether the 3PID medium/address is allowed to be added to this HS
|
||||||
|
"""
|
||||||
|
|
||||||
|
if hs.config.allowed_local_3pids:
|
||||||
|
for constraint in hs.config.allowed_local_3pids:
|
||||||
|
logger.debug(
|
||||||
|
"Checking 3PID %s (%s) against %s (%s)",
|
||||||
|
address, medium, constraint['pattern'], constraint['medium'],
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
medium == constraint['medium'] and
|
||||||
|
re.match(constraint['pattern'], address)
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
@ -12,9 +12,12 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os.path
|
import os.path
|
||||||
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
@ -23,7 +26,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.dir = tempfile.mkdtemp()
|
self.dir = tempfile.mkdtemp()
|
||||||
print self.dir
|
|
||||||
self.file = os.path.join(self.dir, "homeserver.yaml")
|
self.file = os.path.join(self.dir, "homeserver.yaml")
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
@ -48,3 +50,16 @@ class ConfigGenerationTestCase(unittest.TestCase):
|
|||||||
]),
|
]),
|
||||||
set(os.listdir(self.dir))
|
set(os.listdir(self.dir))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.assert_log_filename_is(
|
||||||
|
os.path.join(self.dir, "lemurs.win.log.config"),
|
||||||
|
os.path.join(os.getcwd(), "homeserver.log"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def assert_log_filename_is(self, log_config_file, expected):
|
||||||
|
with open(log_config_file) as f:
|
||||||
|
config = f.read()
|
||||||
|
# find the 'filename' line
|
||||||
|
matches = re.findall("^\s*filename:\s*(.*)$", config, re.M)
|
||||||
|
self.assertEqual(1, len(matches))
|
||||||
|
self.assertEqual(matches[0], expected)
|
||||||
|
@ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def check_context(self, _, expected):
|
def check_context(self, _, expected):
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
getattr(LoggingContext.current_context(), "test_key", None),
|
getattr(LoggingContext.current_context(), "request", None),
|
||||||
expected
|
expected
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
lookup_2_deferred = defer.Deferred()
|
lookup_2_deferred = defer.Deferred()
|
||||||
|
|
||||||
with LoggingContext("one") as context_one:
|
with LoggingContext("one") as context_one:
|
||||||
context_one.test_key = "one"
|
context_one.request = "one"
|
||||||
|
|
||||||
wait_1_deferred = kr.wait_for_previous_lookups(
|
wait_1_deferred = kr.wait_for_previous_lookups(
|
||||||
["server1"],
|
["server1"],
|
||||||
@ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
wait_1_deferred.addBoth(self.check_context, "one")
|
wait_1_deferred.addBoth(self.check_context, "one")
|
||||||
|
|
||||||
with LoggingContext("two") as context_two:
|
with LoggingContext("two") as context_two:
|
||||||
context_two.test_key = "two"
|
context_two.request = "two"
|
||||||
|
|
||||||
# set off another wait. It should block because the first lookup
|
# set off another wait. It should block because the first lookup
|
||||||
# hasn't yet completed.
|
# hasn't yet completed.
|
||||||
@ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_perspectives(**kwargs):
|
def get_perspectives(**kwargs):
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
LoggingContext.current_context().test_key, "11",
|
LoggingContext.current_context().request, "11",
|
||||||
)
|
)
|
||||||
with logcontext.PreserveLoggingContext():
|
with logcontext.PreserveLoggingContext():
|
||||||
yield persp_deferred
|
yield persp_deferred
|
||||||
@ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
self.http_client.post_json.side_effect = get_perspectives
|
self.http_client.post_json.side_effect = get_perspectives
|
||||||
|
|
||||||
with LoggingContext("11") as context_11:
|
with LoggingContext("11") as context_11:
|
||||||
context_11.test_key = "11"
|
context_11.request = "11"
|
||||||
|
|
||||||
# start off a first set of lookups
|
# start off a first set of lookups
|
||||||
res_deferreds = kr.verify_json_objects_for_server(
|
res_deferreds = kr.verify_json_objects_for_server(
|
||||||
@ -167,13 +167,13 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
# wait a tick for it to send the request to the perspectives server
|
# wait a tick for it to send the request to the perspectives server
|
||||||
# (it first tries the datastore)
|
# (it first tries the datastore)
|
||||||
yield async.sleep(0.005)
|
yield async.sleep(1) # XXX find out why this takes so long!
|
||||||
self.http_client.post_json.assert_called_once()
|
self.http_client.post_json.assert_called_once()
|
||||||
|
|
||||||
self.assertIs(LoggingContext.current_context(), context_11)
|
self.assertIs(LoggingContext.current_context(), context_11)
|
||||||
|
|
||||||
context_12 = LoggingContext("12")
|
context_12 = LoggingContext("12")
|
||||||
context_12.test_key = "12"
|
context_12.request = "12"
|
||||||
with logcontext.PreserveLoggingContext(context_12):
|
with logcontext.PreserveLoggingContext(context_12):
|
||||||
# a second request for a server with outstanding requests
|
# a second request for a server with outstanding requests
|
||||||
# should block rather than start a second call
|
# should block rather than start a second call
|
||||||
@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1)],
|
[("server10", json1)],
|
||||||
)
|
)
|
||||||
yield async.sleep(0.005)
|
yield async.sleep(01)
|
||||||
self.http_client.post_json.assert_not_called()
|
self.http_client.post_json.assert_not_called()
|
||||||
res_deferreds_2[0].addBoth(self.check_context, None)
|
res_deferreds_2[0].addBoth(self.check_context, None)
|
||||||
|
|
||||||
@ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase):
|
|||||||
sentinel_context = LoggingContext.current_context()
|
sentinel_context = LoggingContext.current_context()
|
||||||
|
|
||||||
with LoggingContext("one") as context_one:
|
with LoggingContext("one") as context_one:
|
||||||
context_one.test_key = "one"
|
context_one.request = "one"
|
||||||
|
|
||||||
defer = kr.verify_json_for_server("server9", {})
|
defer = kr.verify_json_for_server("server9", {})
|
||||||
try:
|
try:
|
||||||
|
@ -143,7 +143,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||||||
except errors.SynapseError:
|
except errors.SynapseError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.DEBUG
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_claim_one_time_key(self):
|
def test_claim_one_time_key(self):
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
@ -15,6 +15,8 @@
|
|||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock, NonCallableMock
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||||
@ -41,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
|
|||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
|
||||||
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
||||||
listener = reactor.listenUNIX("\0xxx", server_factory)
|
# XXX: mktemp is unsafe and should never be used. but we're just a test.
|
||||||
|
path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
|
||||||
|
listener = reactor.listenUNIX(path, server_factory)
|
||||||
self.addCleanup(listener.stopListening)
|
self.addCleanup(listener.stopListening)
|
||||||
self.streamer = server_factory.streamer
|
self.streamer = server_factory.streamer
|
||||||
|
|
||||||
@ -49,7 +53,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
|
|||||||
client_factory = ReplicationClientFactory(
|
client_factory = ReplicationClientFactory(
|
||||||
self.hs, "client_name", self.replication_handler
|
self.hs, "client_name", self.replication_handler
|
||||||
)
|
)
|
||||||
client_connector = reactor.connectUNIX("\0xxx", client_factory)
|
client_connector = reactor.connectUNIX(path, client_factory)
|
||||||
self.addCleanup(client_factory.stopTrying)
|
self.addCleanup(client_factory.stopTrying)
|
||||||
self.addCleanup(client_connector.disconnect)
|
self.addCleanup(client_connector.disconnect)
|
||||||
|
|
||||||
|
@ -515,9 +515,6 @@ class RoomsCreateTestCase(RestTestCase):
|
|||||||
|
|
||||||
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
|
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_post_room_no_keys(self):
|
def test_post_room_no_keys(self):
|
||||||
# POST with no config keys, expect new room id
|
# POST with no config keys, expect new room id
|
||||||
|
@ -49,6 +49,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||||
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
||||||
self.hs.config.enable_registration = True
|
self.hs.config.enable_registration = True
|
||||||
|
self.hs.config.registrations_require_3pid = []
|
||||||
self.hs.config.auto_join_rooms = []
|
self.hs.config.auto_join_rooms = []
|
||||||
|
|
||||||
# init the thing we're testing
|
# init the thing we're testing
|
||||||
|
88
tests/storage/test_user_directory.py
Normal file
88
tests/storage/test_user_directory.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.storage import UserDirectoryStore
|
||||||
|
from synapse.storage.roommember import ProfileInfo
|
||||||
|
from tests import unittest
|
||||||
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
ALICE = "@alice:a"
|
||||||
|
BOB = "@bob:b"
|
||||||
|
BOBBY = "@bobby:a"
|
||||||
|
|
||||||
|
|
||||||
|
class UserDirectoryStoreTestCase(unittest.TestCase):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def setUp(self):
|
||||||
|
self.hs = yield setup_test_homeserver()
|
||||||
|
self.store = UserDirectoryStore(None, self.hs)
|
||||||
|
|
||||||
|
# alice and bob are both in !room_id. bobby is not but shares
|
||||||
|
# a homeserver with alice.
|
||||||
|
yield self.store.add_profiles_to_user_dir(
|
||||||
|
"!room:id",
|
||||||
|
{
|
||||||
|
ALICE: ProfileInfo(None, "alice"),
|
||||||
|
BOB: ProfileInfo(None, "bob"),
|
||||||
|
BOBBY: ProfileInfo(None, "bobby")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield self.store.add_users_to_public_room(
|
||||||
|
"!room:id",
|
||||||
|
[ALICE, BOB],
|
||||||
|
)
|
||||||
|
yield self.store.add_users_who_share_room(
|
||||||
|
"!room:id",
|
||||||
|
False,
|
||||||
|
(
|
||||||
|
(ALICE, BOB),
|
||||||
|
(BOB, ALICE),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_search_user_dir(self):
|
||||||
|
# normally when alice searches the directory she should just find
|
||||||
|
# bob because bobby doesn't share a room with her.
|
||||||
|
r = yield self.store.search_user_dir(ALICE, "bob", 10)
|
||||||
|
self.assertFalse(r["limited"])
|
||||||
|
self.assertEqual(1, len(r["results"]))
|
||||||
|
self.assertDictEqual(r["results"][0], {
|
||||||
|
"user_id": BOB,
|
||||||
|
"display_name": "bob",
|
||||||
|
"avatar_url": None,
|
||||||
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_search_user_dir_all_users(self):
|
||||||
|
self.hs.config.user_directory_search_all_users = True
|
||||||
|
try:
|
||||||
|
r = yield self.store.search_user_dir(ALICE, "bob", 10)
|
||||||
|
self.assertFalse(r["limited"])
|
||||||
|
self.assertEqual(2, len(r["results"]))
|
||||||
|
self.assertDictEqual(r["results"][0], {
|
||||||
|
"user_id": BOB,
|
||||||
|
"display_name": "bob",
|
||||||
|
"avatar_url": None,
|
||||||
|
})
|
||||||
|
self.assertDictEqual(r["results"][1], {
|
||||||
|
"user_id": BOBBY,
|
||||||
|
"display_name": "bobby",
|
||||||
|
"avatar_url": None,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
self.hs.config.user_directory_search_all_users = False
|
@ -12,7 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import twisted
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -65,6 +65,10 @@ class TestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@around(self)
|
@around(self)
|
||||||
def setUp(orig):
|
def setUp(orig):
|
||||||
|
# enable debugging of delayed calls - this means that we get a
|
||||||
|
# traceback when a unit test exits leaving things on the reactor.
|
||||||
|
twisted.internet.base.DelayedCall.debug = True
|
||||||
|
|
||||||
old_level = logging.getLogger().level
|
old_level = logging.getLogger().level
|
||||||
|
|
||||||
if old_level != level:
|
if old_level != level:
|
||||||
|
176
tests/util/test_file_consumer.py
Normal file
176
tests/util/test_file_consumer.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from twisted.internet import defer, reactor
|
||||||
|
from mock import NonCallableMock
|
||||||
|
|
||||||
|
from synapse.util.file_consumer import BackgroundFileConsumer
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
from StringIO import StringIO
|
||||||
|
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
class FileConsumerTests(unittest.TestCase):
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_pull_consumer(self):
|
||||||
|
string_file = StringIO()
|
||||||
|
consumer = BackgroundFileConsumer(string_file)
|
||||||
|
|
||||||
|
try:
|
||||||
|
producer = DummyPullProducer()
|
||||||
|
|
||||||
|
yield producer.register_with_consumer(consumer)
|
||||||
|
|
||||||
|
yield producer.write_and_wait("Foo")
|
||||||
|
|
||||||
|
self.assertEqual(string_file.getvalue(), "Foo")
|
||||||
|
|
||||||
|
yield producer.write_and_wait("Bar")
|
||||||
|
|
||||||
|
self.assertEqual(string_file.getvalue(), "FooBar")
|
||||||
|
finally:
|
||||||
|
consumer.unregisterProducer()
|
||||||
|
|
||||||
|
yield consumer.wait()
|
||||||
|
|
||||||
|
self.assertTrue(string_file.closed)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_push_consumer(self):
|
||||||
|
string_file = BlockingStringWrite()
|
||||||
|
consumer = BackgroundFileConsumer(string_file)
|
||||||
|
|
||||||
|
try:
|
||||||
|
producer = NonCallableMock(spec_set=[])
|
||||||
|
|
||||||
|
consumer.registerProducer(producer, True)
|
||||||
|
|
||||||
|
consumer.write("Foo")
|
||||||
|
yield string_file.wait_for_n_writes(1)
|
||||||
|
|
||||||
|
self.assertEqual(string_file.buffer, "Foo")
|
||||||
|
|
||||||
|
consumer.write("Bar")
|
||||||
|
yield string_file.wait_for_n_writes(2)
|
||||||
|
|
||||||
|
self.assertEqual(string_file.buffer, "FooBar")
|
||||||
|
finally:
|
||||||
|
consumer.unregisterProducer()
|
||||||
|
|
||||||
|
yield consumer.wait()
|
||||||
|
|
||||||
|
self.assertTrue(string_file.closed)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_push_producer_feedback(self):
|
||||||
|
string_file = BlockingStringWrite()
|
||||||
|
consumer = BackgroundFileConsumer(string_file)
|
||||||
|
|
||||||
|
try:
|
||||||
|
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
|
||||||
|
|
||||||
|
resume_deferred = defer.Deferred()
|
||||||
|
producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None)
|
||||||
|
|
||||||
|
consumer.registerProducer(producer, True)
|
||||||
|
|
||||||
|
number_writes = 0
|
||||||
|
with string_file.write_lock:
|
||||||
|
for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
|
||||||
|
consumer.write("Foo")
|
||||||
|
number_writes += 1
|
||||||
|
|
||||||
|
producer.pauseProducing.assert_called_once()
|
||||||
|
|
||||||
|
yield string_file.wait_for_n_writes(number_writes)
|
||||||
|
|
||||||
|
yield resume_deferred
|
||||||
|
producer.resumeProducing.assert_called_once()
|
||||||
|
finally:
|
||||||
|
consumer.unregisterProducer()
|
||||||
|
|
||||||
|
yield consumer.wait()
|
||||||
|
|
||||||
|
self.assertTrue(string_file.closed)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyPullProducer(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.consumer = None
|
||||||
|
self.deferred = defer.Deferred()
|
||||||
|
|
||||||
|
def resumeProducing(self):
|
||||||
|
d = self.deferred
|
||||||
|
self.deferred = defer.Deferred()
|
||||||
|
d.callback(None)
|
||||||
|
|
||||||
|
def write_and_wait(self, bytes):
|
||||||
|
d = self.deferred
|
||||||
|
self.consumer.write(bytes)
|
||||||
|
return d
|
||||||
|
|
||||||
|
def register_with_consumer(self, consumer):
|
||||||
|
d = self.deferred
|
||||||
|
self.consumer = consumer
|
||||||
|
self.consumer.registerProducer(self, False)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
class BlockingStringWrite(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.buffer = ""
|
||||||
|
self.closed = False
|
||||||
|
self.write_lock = threading.Lock()
|
||||||
|
|
||||||
|
self._notify_write_deferred = None
|
||||||
|
self._number_of_writes = 0
|
||||||
|
|
||||||
|
def write(self, bytes):
|
||||||
|
with self.write_lock:
|
||||||
|
self.buffer += bytes
|
||||||
|
self._number_of_writes += 1
|
||||||
|
|
||||||
|
reactor.callFromThread(self._notify_write)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
def _notify_write(self):
|
||||||
|
"Called by write to indicate a write happened"
|
||||||
|
with self.write_lock:
|
||||||
|
if not self._notify_write_deferred:
|
||||||
|
return
|
||||||
|
d = self._notify_write_deferred
|
||||||
|
self._notify_write_deferred = None
|
||||||
|
d.callback(None)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def wait_for_n_writes(self, n):
|
||||||
|
"Wait for n writes to have happened"
|
||||||
|
while True:
|
||||||
|
with self.write_lock:
|
||||||
|
if n <= self._number_of_writes:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self._notify_write_deferred:
|
||||||
|
self._notify_write_deferred = defer.Deferred()
|
||||||
|
|
||||||
|
d = self._notify_write_deferred
|
||||||
|
|
||||||
|
yield d
|
@ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def _check_test_key(self, value):
|
def _check_test_key(self, value):
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
LoggingContext.current_context().test_key, value
|
LoggingContext.current_context().request, value
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_with_context(self):
|
def test_with_context(self):
|
||||||
with LoggingContext() as context_one:
|
with LoggingContext() as context_one:
|
||||||
context_one.test_key = "test"
|
context_one.request = "test"
|
||||||
self._check_test_key("test")
|
self._check_test_key("test")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def competing_callback():
|
def competing_callback():
|
||||||
with LoggingContext() as competing_context:
|
with LoggingContext() as competing_context:
|
||||||
competing_context.test_key = "competing"
|
competing_context.request = "competing"
|
||||||
yield sleep(0)
|
yield sleep(0)
|
||||||
self._check_test_key("competing")
|
self._check_test_key("competing")
|
||||||
|
|
||||||
reactor.callLater(0, competing_callback)
|
reactor.callLater(0, competing_callback)
|
||||||
|
|
||||||
with LoggingContext() as context_one:
|
with LoggingContext() as context_one:
|
||||||
context_one.test_key = "one"
|
context_one.request = "one"
|
||||||
yield sleep(0)
|
yield sleep(0)
|
||||||
self._check_test_key("one")
|
self._check_test_key("one")
|
||||||
|
|
||||||
@ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def cb():
|
def cb():
|
||||||
context_one.test_key = "one"
|
context_one.request = "one"
|
||||||
yield function()
|
yield function()
|
||||||
self._check_test_key("one")
|
self._check_test_key("one")
|
||||||
|
|
||||||
callback_completed[0] = True
|
callback_completed[0] = True
|
||||||
|
|
||||||
with LoggingContext() as context_one:
|
with LoggingContext() as context_one:
|
||||||
context_one.test_key = "one"
|
context_one.request = "one"
|
||||||
|
|
||||||
# fire off function, but don't wait on it.
|
# fire off function, but don't wait on it.
|
||||||
logcontext.preserve_fn(cb)()
|
logcontext.preserve_fn(cb)()
|
||||||
@ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase):
|
|||||||
sentinel_context = LoggingContext.current_context()
|
sentinel_context = LoggingContext.current_context()
|
||||||
|
|
||||||
with LoggingContext() as context_one:
|
with LoggingContext() as context_one:
|
||||||
context_one.test_key = "one"
|
context_one.request = "one"
|
||||||
|
|
||||||
d1 = logcontext.make_deferred_yieldable(blocking_function())
|
d1 = logcontext.make_deferred_yieldable(blocking_function())
|
||||||
# make sure that the context was reset by make_deferred_yieldable
|
# make sure that the context was reset by make_deferred_yieldable
|
||||||
@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
|
|||||||
argument isn't actually a deferred"""
|
argument isn't actually a deferred"""
|
||||||
|
|
||||||
with LoggingContext() as context_one:
|
with LoggingContext() as context_one:
|
||||||
context_one.test_key = "one"
|
context_one.request = "one"
|
||||||
|
|
||||||
d1 = logcontext.make_deferred_yieldable("bum")
|
d1 = logcontext.make_deferred_yieldable("bum")
|
||||||
self._check_test_key("one")
|
self._check_test_key("one")
|
||||||
|
249
tests/utils.py
249
tests/utils.py
@ -13,27 +13,28 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.http.server import HttpServer
|
|
||||||
from synapse.api.errors import cs_error, CodeMessageException, StoreError
|
|
||||||
from synapse.api.constants import EventTypes
|
|
||||||
from synapse.storage.prepare_database import prepare_database
|
|
||||||
from synapse.storage.engines import create_engine
|
|
||||||
from synapse.server import HomeServer
|
|
||||||
from synapse.federation.transport import server
|
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
|
||||||
|
|
||||||
from synapse.util.logcontext import LoggingContext
|
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
|
||||||
from twisted.enterprise.adbapi import ConnectionPool
|
|
||||||
|
|
||||||
from collections import namedtuple
|
|
||||||
from mock import patch, Mock
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
from inspect import getcallargs
|
||||||
import urllib
|
import urllib
|
||||||
import urlparse
|
import urlparse
|
||||||
|
|
||||||
from inspect import getcallargs
|
from mock import Mock, patch
|
||||||
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
|
from synapse.api.errors import CodeMessageException, cs_error
|
||||||
|
from synapse.federation.transport import server
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage import PostgresEngine
|
||||||
|
from synapse.storage.engines import create_engine
|
||||||
|
from synapse.storage.prepare_database import prepare_database
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
|
|
||||||
|
# set this to True to run the tests against postgres instead of sqlite.
|
||||||
|
# It requires you to have a local postgres database called synapse_test, within
|
||||||
|
# which ALL TABLES WILL BE DROPPED
|
||||||
|
USE_POSTGRES_FOR_TESTS = False
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -57,32 +58,70 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|||||||
config.worker_app = None
|
config.worker_app = None
|
||||||
config.email_enable_notifs = False
|
config.email_enable_notifs = False
|
||||||
config.block_non_admin_invites = False
|
config.block_non_admin_invites = False
|
||||||
|
config.federation_domain_whitelist = None
|
||||||
|
config.user_directory_search_all_users = False
|
||||||
|
|
||||||
|
# disable user directory updates, because they get done in the
|
||||||
|
# background, which upsets the test runner.
|
||||||
|
config.update_user_directory = False
|
||||||
|
|
||||||
config.use_frozen_dicts = True
|
config.use_frozen_dicts = True
|
||||||
config.database_config = {"name": "sqlite3"}
|
|
||||||
config.ldap_enabled = False
|
config.ldap_enabled = False
|
||||||
|
|
||||||
if "clock" not in kargs:
|
if "clock" not in kargs:
|
||||||
kargs["clock"] = MockClock()
|
kargs["clock"] = MockClock()
|
||||||
|
|
||||||
|
if USE_POSTGRES_FOR_TESTS:
|
||||||
|
config.database_config = {
|
||||||
|
"name": "psycopg2",
|
||||||
|
"args": {
|
||||||
|
"database": "synapse_test",
|
||||||
|
"cp_min": 1,
|
||||||
|
"cp_max": 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config.database_config = {
|
||||||
|
"name": "sqlite3",
|
||||||
|
"args": {
|
||||||
|
"database": ":memory:",
|
||||||
|
"cp_min": 1,
|
||||||
|
"cp_max": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
db_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
|
# we need to configure the connection pool to run the on_new_connection
|
||||||
|
# function, so that we can test code that uses custom sqlite functions
|
||||||
|
# (like rank).
|
||||||
|
config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
|
||||||
|
|
||||||
if datastore is None:
|
if datastore is None:
|
||||||
db_pool = SQLiteMemoryDbPool()
|
|
||||||
yield db_pool.prepare()
|
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
name, db_pool=db_pool, config=config,
|
name, config=config,
|
||||||
|
db_config=config.database_config,
|
||||||
version_string="Synapse/tests",
|
version_string="Synapse/tests",
|
||||||
database_engine=create_engine(config.database_config),
|
database_engine=db_engine,
|
||||||
get_db_conn=db_pool.get_db_conn,
|
|
||||||
room_list_handler=object(),
|
room_list_handler=object(),
|
||||||
tls_server_context_factory=Mock(),
|
tls_server_context_factory=Mock(),
|
||||||
**kargs
|
**kargs
|
||||||
)
|
)
|
||||||
|
db_conn = hs.get_db_conn()
|
||||||
|
# make sure that the database is empty
|
||||||
|
if isinstance(db_engine, PostgresEngine):
|
||||||
|
cur = db_conn.cursor()
|
||||||
|
cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
|
||||||
|
rows = cur.fetchall()
|
||||||
|
for r in rows:
|
||||||
|
cur.execute("DROP TABLE %s CASCADE" % r[0])
|
||||||
|
yield prepare_database(db_conn, db_engine, config)
|
||||||
hs.setup()
|
hs.setup()
|
||||||
else:
|
else:
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
name, db_pool=None, datastore=datastore, config=config,
|
name, db_pool=None, datastore=datastore, config=config,
|
||||||
version_string="Synapse/tests",
|
version_string="Synapse/tests",
|
||||||
database_engine=create_engine(config.database_config),
|
database_engine=db_engine,
|
||||||
room_list_handler=object(),
|
room_list_handler=object(),
|
||||||
tls_server_context_factory=Mock(),
|
tls_server_context_factory=Mock(),
|
||||||
**kargs
|
**kargs
|
||||||
@ -301,168 +340,6 @@ class MockClock(object):
|
|||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
class SQLiteMemoryDbPool(ConnectionPool, object):
|
|
||||||
def __init__(self):
|
|
||||||
super(SQLiteMemoryDbPool, self).__init__(
|
|
||||||
"sqlite3", ":memory:",
|
|
||||||
cp_min=1,
|
|
||||||
cp_max=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.config = Mock()
|
|
||||||
self.config.password_providers = []
|
|
||||||
self.config.database_config = {"name": "sqlite3"}
|
|
||||||
|
|
||||||
def prepare(self):
|
|
||||||
engine = self.create_engine()
|
|
||||||
return self.runWithConnection(
|
|
||||||
lambda conn: prepare_database(conn, engine, self.config)
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_db_conn(self):
|
|
||||||
conn = self.connect()
|
|
||||||
engine = self.create_engine()
|
|
||||||
prepare_database(conn, engine, self.config)
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def create_engine(self):
|
|
||||||
return create_engine(self.config.database_config)
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryDataStore(object):
|
|
||||||
|
|
||||||
Room = namedtuple(
|
|
||||||
"Room",
|
|
||||||
["room_id", "is_public", "creator"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.tokens_to_users = {}
|
|
||||||
self.paths_to_content = {}
|
|
||||||
|
|
||||||
self.members = {}
|
|
||||||
self.rooms = {}
|
|
||||||
|
|
||||||
self.current_state = {}
|
|
||||||
self.events = []
|
|
||||||
|
|
||||||
class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")):
|
|
||||||
def fill_out_prev_events(self, event):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
|
|
||||||
return self.Snapshot(
|
|
||||||
room_id, user_id, self.get_room_member(user_id, room_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
def register(self, user_id, token, password_hash):
|
|
||||||
if user_id in self.tokens_to_users.values():
|
|
||||||
raise StoreError(400, "User in use.")
|
|
||||||
self.tokens_to_users[token] = user_id
|
|
||||||
|
|
||||||
def get_user_by_access_token(self, token):
|
|
||||||
try:
|
|
||||||
return {
|
|
||||||
"name": self.tokens_to_users[token],
|
|
||||||
}
|
|
||||||
except Exception:
|
|
||||||
raise StoreError(400, "User does not exist.")
|
|
||||||
|
|
||||||
def get_room(self, room_id):
|
|
||||||
try:
|
|
||||||
return self.rooms[room_id]
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def store_room(self, room_id, room_creator_user_id, is_public):
|
|
||||||
if room_id in self.rooms:
|
|
||||||
raise StoreError(409, "Conflicting room!")
|
|
||||||
|
|
||||||
room = MemoryDataStore.Room(
|
|
||||||
room_id=room_id,
|
|
||||||
is_public=is_public,
|
|
||||||
creator=room_creator_user_id
|
|
||||||
)
|
|
||||||
self.rooms[room_id] = room
|
|
||||||
|
|
||||||
def get_room_member(self, user_id, room_id):
|
|
||||||
return self.members.get(room_id, {}).get(user_id)
|
|
||||||
|
|
||||||
def get_room_members(self, room_id, membership=None):
|
|
||||||
if membership:
|
|
||||||
return [
|
|
||||||
v for k, v in self.members.get(room_id, {}).items()
|
|
||||||
if v.membership == membership
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
return self.members.get(room_id, {}).values()
|
|
||||||
|
|
||||||
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
|
|
||||||
return [
|
|
||||||
m[user_id] for m in self.members.values()
|
|
||||||
if user_id in m and m[user_id].membership in membership_list
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
|
|
||||||
limit=0, with_feedback=False):
|
|
||||||
return ([], from_key) # TODO
|
|
||||||
|
|
||||||
def get_joined_hosts_for_room(self, room_id):
|
|
||||||
return defer.succeed([])
|
|
||||||
|
|
||||||
def persist_event(self, event):
|
|
||||||
if event.type == EventTypes.Member:
|
|
||||||
room_id = event.room_id
|
|
||||||
user = event.state_key
|
|
||||||
self.members.setdefault(room_id, {})[user] = event
|
|
||||||
|
|
||||||
if hasattr(event, "state_key"):
|
|
||||||
key = (event.room_id, event.type, event.state_key)
|
|
||||||
self.current_state[key] = event
|
|
||||||
|
|
||||||
self.events.append(event)
|
|
||||||
|
|
||||||
def get_current_state(self, room_id, event_type=None, state_key=""):
|
|
||||||
if event_type:
|
|
||||||
key = (room_id, event_type, state_key)
|
|
||||||
if self.current_state.get(key):
|
|
||||||
return [self.current_state.get(key)]
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return [
|
|
||||||
e for e in self.current_state
|
|
||||||
if e[0] == room_id
|
|
||||||
]
|
|
||||||
|
|
||||||
def set_presence_state(self, user_localpart, state):
|
|
||||||
return defer.succeed({"state": 0})
|
|
||||||
|
|
||||||
def get_presence_list(self, user_localpart, accepted):
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_room_events_max_id(self):
|
|
||||||
return "s0" # TODO (erikj)
|
|
||||||
|
|
||||||
def get_send_event_level(self, room_id):
|
|
||||||
return defer.succeed(0)
|
|
||||||
|
|
||||||
def get_power_level(self, room_id, user_id):
|
|
||||||
return defer.succeed(0)
|
|
||||||
|
|
||||||
def get_add_state_level(self, room_id):
|
|
||||||
return defer.succeed(0)
|
|
||||||
|
|
||||||
def get_room_join_rule(self, room_id):
|
|
||||||
# TODO (erikj): This should be configurable
|
|
||||||
return defer.succeed("invite")
|
|
||||||
|
|
||||||
def get_ops_levels(self, room_id):
|
|
||||||
return defer.succeed((5, 5, 5))
|
|
||||||
|
|
||||||
def insert_client_ip(self, user, access_token, ip, user_agent):
|
|
||||||
return defer.succeed(None)
|
|
||||||
|
|
||||||
|
|
||||||
def _format_call(args, kwargs):
|
def _format_call(args, kwargs):
|
||||||
return ", ".join(
|
return ", ".join(
|
||||||
["%r" % (a) for a in args] +
|
["%r" % (a) for a in args] +
|
||||||
|
Loading…
Reference in New Issue
Block a user